konfai-1.1.5-py3-none-any.whl → konfai-1.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. See the package's registry page for more details.

@@ -57,7 +57,7 @@ class DataAugmentationsList():
57
57
 
58
58
  def load(self, key: str, datasets: list[Dataset]):
59
59
  for augmentation, prob in self.dataAugmentationsLoader.items():
60
- module, name = _getModule(augmentation, "data.augmentation")
60
+ module, name = _getModule(augmentation, "konfai.data.augmentation")
61
61
  dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
62
62
  dataAugmentation.load(prob.prob)
63
63
  dataAugmentation.setDatasets(datasets)
@@ -264,7 +264,6 @@ class LumaFlip(ColorTransform):
264
264
 
265
265
  class HUE(ColorTransform):
266
266
 
267
- @config("HUE")
268
267
  def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
269
268
  super().__init__(groups)
270
269
  self.hue_max = hue_max
@@ -24,11 +24,11 @@ from konfai.data.augmentation import DataAugmentationsList
24
24
  class GroupTransform:
25
25
 
26
26
  @config()
27
- def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
28
- post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
27
+ def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
28
+ patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
29
29
  isInput: bool = True) -> None:
30
- self._pre_transforms = pre_transforms
31
- self._post_transforms = post_transforms
30
+ self._pre_transforms = transforms
31
+ self._post_transforms = patch_transforms
32
32
  self.pre_transforms : list[Transform] = []
33
33
  self.post_transforms : list[Transform] = []
34
34
  self.isInput = isInput
@@ -37,7 +37,7 @@ class GroupTransform:
37
37
  if self._pre_transforms is not None:
38
38
  if isinstance(self._pre_transforms, dict):
39
39
  for classpath, transform in self._pre_transforms.items():
40
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(KONFAI_ROOT(), group_src, group_dest))
40
+ transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
41
41
  transform.setDatasets(datasets)
42
42
  self.pre_transforms.append(transform)
43
43
  else:
@@ -48,7 +48,7 @@ class GroupTransform:
48
48
  if self._post_transforms is not None:
49
49
  if isinstance(self._post_transforms, dict):
50
50
  for classpath, transform in self._post_transforms.items():
51
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(KONFAI_ROOT(), group_src, group_dest))
51
+ transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
52
52
  transform.setDatasets(datasets)
53
53
  self.post_transforms.append(transform)
54
54
  else:
@@ -65,9 +65,8 @@ class GroupTransform:
65
65
  class GroupTransformMetric(GroupTransform):
66
66
 
67
67
  @config()
68
- def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
69
- post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
70
- super().__init__(pre_transforms, post_transforms)
68
+ def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
69
+ super().__init__(transforms, None)
71
70
 
72
71
  class Group(dict[str, GroupTransform]):
73
72
 
@@ -184,7 +183,7 @@ class DatasetIter(data.Dataset):
184
183
  data = {}
185
184
  x, a, p = self.map[index]
186
185
  if x not in self._index_cache:
187
- if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
186
+ if len(self._index_cache) >= self.buffer_size and not self.use_cache:
188
187
  self._unloadData(self._index_cache[0])
189
188
  self._loadData(x)
190
189
 
@@ -196,10 +195,9 @@ class DatasetIter(data.Dataset):
196
195
 
197
196
  class Subset():
198
197
 
199
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
198
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
200
199
  self.subset = subset
201
200
  self.shuffle = shuffle
202
- self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None
203
201
 
204
202
  def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
205
203
  inter_name = set(names[0])
@@ -207,16 +205,7 @@ class Subset():
207
205
  inter_name = inter_name.intersection(set(n))
208
206
  names = sorted(list(inter_name))
209
207
 
210
- names_filtred = []
211
- if self.filter is not None:
212
- for name in names:
213
- for info in infos:
214
- if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
215
- names_filtred.append(name)
216
- continue
217
- else:
218
- names_filtred = names
219
- size = len(names_filtred)
208
+ size = len(names)
220
209
  index = []
221
210
  if self.subset is None:
222
211
  index = list(range(0, size))
@@ -230,7 +219,7 @@ class Subset():
230
219
  for name in f:
231
220
  train_names.append(name.strip())
232
221
  index = []
233
- for i, name in enumerate(names_filtred):
222
+ for i, name in enumerate(names):
234
223
  if name in train_names:
235
224
  index.append(i)
236
225
  elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
@@ -239,7 +228,7 @@ class Subset():
239
228
  for name in f:
240
229
  exclude_names.append(name.strip())
241
230
  index = []
242
- for i, name in enumerate(names_filtred):
231
+ for i, name in enumerate(names):
243
232
  if name not in exclude_names:
244
233
  index.append(i)
245
234
 
@@ -252,26 +241,27 @@ class Subset():
252
241
  index = self.subset
253
242
  if isinstance(self.subset[0], str):
254
243
  index = []
255
- for i, name in enumerate(names_filtred):
244
+ for i, name in enumerate(names):
256
245
  if name in self.subset:
257
246
  index.append(i)
258
247
  if self.shuffle:
259
248
  index = random.sample(index, len(index))
260
- return set([names_filtred[i] for i in index])
249
+ return set([names[i] for i in index])
261
250
 
262
251
  def __str__(self):
263
- return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
252
+ return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
253
+
264
254
  class TrainSubset(Subset):
265
255
 
266
256
  @config()
267
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
268
- super().__init__(subset, shuffle, filter)
257
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
258
+ super().__init__(subset, shuffle)
269
259
 
270
260
  class PredictionSubset(Subset):
271
261
 
272
262
  @config()
273
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, filter: str = None) -> None:
274
- super().__init__(subset, False, filter)
263
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
264
+ super().__init__(subset, False)
275
265
 
276
266
  class Data(ABC):
277
267
 
@@ -280,7 +270,6 @@ class Data(ABC):
280
270
  patch : Union[DatasetPatch, None],
281
271
  use_cache : bool,
282
272
  subset : Subset,
283
- num_workers : int,
284
273
  batch_size : int,
285
274
  validation: Union[float, str, list[int], list[str], None] = None,
286
275
  inlineAugmentations: bool = False,
@@ -293,7 +282,7 @@ class Data(ABC):
293
282
  self.dataAugmentationsList = dataAugmentationsList
294
283
  self.batch_size = batch_size
295
284
  self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
296
- self.dataLoader_args = dict(num_workers=num_workers if use_cache else 0, pin_memory=True)
285
+ self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]), pin_memory=True)
297
286
  self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
298
287
  self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
299
288
  self.datasets: dict[str, Dataset] = {}
@@ -547,10 +536,9 @@ class DataTrain(Data):
547
536
  patch : Union[DatasetPatch, None] = DatasetPatch(),
548
537
  use_cache : bool = True,
549
538
  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
550
- num_workers : int = 4,
551
539
  batch_size : int = 1,
552
540
  validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
553
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
541
+ super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
554
542
 
555
543
  class DataPrediction(Data):
556
544
 
@@ -560,10 +548,9 @@ class DataPrediction(Data):
560
548
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
561
549
  patch : Union[DatasetPatch, None] = DatasetPatch(),
562
550
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
563
- num_workers : int = 4,
564
551
  batch_size : int = 1) -> None:
565
552
 
566
- super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
553
+ super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
567
554
 
568
555
  class DataMetric(Data):
569
556
 
@@ -571,7 +558,6 @@ class DataMetric(Data):
571
558
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
572
559
  groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
573
560
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
574
- validation: Union[str, None] = None,
575
- num_workers : int = 4) -> None:
561
+ validation: Union[str, None] = None) -> None:
576
562
 
577
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
563
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
konfai/data/patching.py CHANGED
@@ -81,7 +81,6 @@ class PathCombine(ABC):
81
81
 
82
82
  class Mean(PathCombine):
83
83
 
84
- @config("Mean")
85
84
  def __init__(self) -> None:
86
85
  super().__init__()
87
86
 
@@ -90,7 +89,6 @@ class Mean(PathCombine):
90
89
 
91
90
  class Cosinus(PathCombine):
92
91
 
93
- @config("Cosinus")
94
92
  def __init__(self) -> None:
95
93
  super().__init__()
96
94
 
@@ -144,7 +142,7 @@ class Accumulator():
144
142
 
145
143
  class Patch(ABC):
146
144
 
147
- def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
145
+ def __init__(self, patch_size: list[int], overlap: Union[int, None], padValue: float = 0, extend_slice: int = 0) -> None:
148
146
  self.patch_size = patch_size
149
147
  self.overlap = overlap
150
148
  if isinstance(self.overlap, int):
@@ -152,15 +150,8 @@ class Patch(ABC):
152
150
  self.overlap = None
153
151
  self._patch_slices : dict[int, list[tuple[slice]]] = {}
154
152
  self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
155
- self.path_mask = path_mask
156
- self.mask = None
157
153
  self.padValue = padValue
158
154
  self.extend_slice = extend_slice
159
- if self.path_mask is not None:
160
- if os.path.exists(self.path_mask):
161
- self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path_mask)))
162
- else:
163
- raise NameError('Mask file not found')
164
155
 
165
156
  def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
166
157
  self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
@@ -199,9 +190,6 @@ class Patch(ABC):
199
190
  padding.append(p)
200
191
 
201
192
  data_sliced = F.pad(data_sliced, tuple(padding), "constant", 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
202
- if self.mask is not None:
203
- outside = torch.ones_like(data_sliced)*(0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
204
- data_sliced = torch.where(self.mask == 0, outside, data_sliced)
205
193
 
206
194
  for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
207
195
  data_sliced = torch.squeeze(data_sliced, dim = len(data_sliced.shape)-d-1)
@@ -213,14 +201,14 @@ class Patch(ABC):
213
201
  class DatasetPatch(Patch):
214
202
 
215
203
  @config("Patch")
216
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
217
- super().__init__(patch_size, overlap, mask, padValue, extend_slice)
204
+ def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
205
+ super().__init__(patch_size, overlap, padValue, extend_slice)
218
206
 
219
207
  class ModelPatch(Patch):
220
208
 
221
209
  @config("Patch")
222
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
223
- super().__init__(patch_size, overlap, mask, padValue, extend_slice)
210
+ def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
211
+ super().__init__(patch_size, overlap, padValue, extend_slice)
224
212
  self.patchCombine = patchCombine
225
213
 
226
214
  def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
@@ -247,7 +235,7 @@ class DatasetManager():
247
235
  for transformFunction in pre_transforms:
248
236
  _shape = transformFunction.transformShape(_shape, cache_attribute)
249
237
 
250
- self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, mask=patch.path_mask, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
238
+ self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
251
239
  self.patch.load(_shape, 0)
252
240
  self.shape = _shape
253
241
  self.dataAugmentationsList = dataAugmentationsList
@@ -314,7 +302,7 @@ class DatasetManager():
314
302
  dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
315
303
  self.data : list[torch.Tensor] = list()
316
304
  self.data.append(data)
317
-
305
+
318
306
  for i in range(len(self.cache_attributes)-1):
319
307
  self.cache_attributes[i+1].update(self.cache_attributes[0])
320
308
  self.loaded = True
konfai/data/transform.py CHANGED
@@ -36,7 +36,7 @@ class TransformLoader:
36
36
  pass
37
37
 
38
38
  def getTransform(self, classpath : str, DL_args : str) -> Transform:
39
- module, name = _getModule(classpath, "data.transform")
39
+ module, name = _getModule(classpath, "konfai.data.transform")
40
40
  return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config = None)
41
41
 
42
42
  class Clip(Transform):
@@ -213,7 +213,7 @@ class Resample(Transform, ABC):
213
213
 
214
214
  class ResampleToResolution(Resample):
215
215
 
216
- def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
216
+ def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
217
217
  self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
218
218
 
219
219
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -244,34 +244,36 @@ class ResampleToResolution(Resample):
244
244
  cache_attribute["Size"] = np.asarray(size)
245
245
  return self._resample(input, size)
246
246
 
247
- class ResampleToSize(Resample):
247
+ class ResampleToShape(Resample):
248
248
 
249
- def __init__(self, size : list[int] = [100,512,512]) -> None:
250
- self.size = size
249
+ def __init__(self, shape : list[float] = [100,256,256]) -> None:
250
+ self.shape = torch.tensor([0 if s < 0 else s for s in shape])
251
251
 
252
252
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
253
+ print(shape)
253
254
  if "Spacing" not in cache_attribute:
254
255
  TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
255
256
  "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
256
- if len(shape) != len(self.size):
257
+ if len(shape) != len(self.shape):
257
258
  TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
258
- size = self.size
259
- for i, s in enumerate(self.size):
260
- if s == -1:
261
- size[i] = shape[i]
262
- return size
259
+ new_shape = self.shape
260
+ for i, s in enumerate(self.shape):
261
+ if s == 0:
262
+ new_shape[i] = shape[i]
263
+ print(new_shape)
264
+ return new_shape
263
265
 
264
266
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
265
- size = self.size
266
- image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
267
- for i, s in enumerate(self.size):
268
- if s is None:
269
- size[i] = image_size[i]
267
+ shape = self.shape
268
+ image_shape = torch.tensor([int(x) for x in torch.tensor(input.shape[1:])])
269
+ for i, s in enumerate(self.shape):
270
+ if s == 0:
271
+ shape[i] = image_shape[i]
270
272
  if "Spacing" in cache_attribute:
271
- cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
272
- cache_attribute["Size"] = image_size
273
- cache_attribute["Size"] = size
274
- return self._resample(input, size)
273
+ cache_attribute["Spacing"] = torch.flip(image_shape/shape*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
274
+ cache_attribute["Size"] = image_shape
275
+ cache_attribute["Size"] = shape
276
+ return self._resample(input, shape)
275
277
 
276
278
  class ResampleTransform(Transform):
277
279
 
@@ -410,7 +412,7 @@ class Gradient(Transform):
410
412
  def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
411
413
  return input
412
414
 
413
- class ArgMax(Transform):
415
+ class Argmax(Transform):
414
416
 
415
417
  def __init__(self, dim: int = 0) -> None:
416
418
  self.dim = dim
konfai/evaluator.py CHANGED
@@ -27,7 +27,7 @@ class CriterionsLoader():
27
27
  def getCriterions(self, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
28
28
  criterions = {}
29
29
  for module_classpath, criterionsAttr in self.criterionsLoader.items():
30
- module, name = _getModule(module_classpath, "metric.measure")
30
+ module, name = _getModule(module_classpath, "konfai.metric.measure")
31
31
  criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
32
32
  return criterions
33
33
 
konfai/main.py CHANGED
@@ -9,8 +9,6 @@ import sys
9
9
  sys.path.insert(0, os.getcwd())
10
10
 
11
11
  def main():
12
- import tracemalloc
13
-
14
12
  parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
15
13
  try:
16
14
  with setup(parser) as distributedObject:
konfai/metric/measure.py CHANGED
@@ -9,7 +9,6 @@ from scipy import linalg
9
9
 
10
10
  import torch.nn.functional as F
11
11
  import os
12
- import tqdm
13
12
 
14
13
  from typing import Callable, Union
15
14
  from functools import partial
@@ -19,10 +18,11 @@ import copy
19
18
  from abc import abstractmethod
20
19
 
21
20
  from konfai.utils.config import config
22
- from konfai.utils.utils import _getModule
21
+ from konfai.utils.utils import _getModule, download_url
23
22
  from konfai.data.patching import ModelPatch
24
23
  from konfai.network.blocks import LatentDistribution
25
24
  from konfai.network.network import ModelLoader, Network
25
+ from tqdm import tqdm
26
26
 
27
27
  modelsRegister = {}
28
28
 
@@ -122,7 +122,7 @@ class LPIPS(MaskedLoss):
122
122
 
123
123
  patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
124
124
  loss = 0
125
- with tqdm.tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
125
+ with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
126
126
  for i, patch_input in batch_iter:
127
127
  real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
128
128
  loss += loss_fn_alex(real, fake).item()
@@ -146,9 +146,11 @@ class Dice(Criterion):
146
146
  target = self.flatten(target)
147
147
  return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
148
148
 
149
-
150
149
  def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
151
- target = targets[0]
150
+ target = F.interpolate(
151
+ targets[0],
152
+ output.shape[2:],
153
+ mode="nearest")
152
154
  if output.shape[1] == 1:
153
155
  output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
154
156
  target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
@@ -242,7 +244,7 @@ class PerceptualLoss(Criterion):
242
244
 
243
245
  class Module():
244
246
 
245
- @config(None)
247
+ @config()
246
248
  def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
247
249
  self.losses = losses
248
250
  self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
@@ -250,7 +252,7 @@ class PerceptualLoss(Criterion):
250
252
  def getLoss(self) -> dict[torch.nn.Module, float]:
251
253
  result: dict[torch.nn.Module, float] = {}
252
254
  for loss, l in self.losses.items():
253
- module, name = _getModule(loss, "metric.measure")
255
+ module, name = _getModule(loss, "konfai.metric.measure")
254
256
  result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
255
257
  return result
256
258
 
@@ -413,16 +415,19 @@ class FocalLoss(Criterion):
413
415
  self.gamma = gamma
414
416
  self.reduction = reduction
415
417
 
416
- def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
417
- target = targets[0].long()
418
+ def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
419
+ target = F.interpolate(
420
+ targets[0],
421
+ output.shape[2:],
422
+ mode="nearest").long()
418
423
 
419
424
  logpt = F.log_softmax(output, dim=1)
420
425
  pt = torch.exp(logpt)
421
426
 
422
- logpt = logpt.gather(1, target.unsqueeze(1))
423
- pt = pt.gather(1, target.unsqueeze(1))
427
+ logpt = logpt.gather(1, target)
428
+ pt = pt.gather(1, target)
424
429
 
425
- at = self.alpha[target].unsqueeze(1)
430
+ at = self.alpha.to(target.device)[target].unsqueeze(1)
426
431
  loss = -at * ((1 - pt) ** self.gamma) * logpt
427
432
 
428
433
  if self.reduction == "mean":
@@ -458,13 +463,11 @@ class FID(Criterion):
458
463
  return features
459
464
 
460
465
  def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
461
- # Calculate mean and covariance statistics
462
466
  mu1 = np.mean(real_features, axis=0)
463
467
  sigma1 = np.cov(real_features, rowvar=False)
464
468
  mu2 = np.mean(generated_features, axis=0)
465
469
  sigma2 = np.cov(generated_features, rowvar=False)
466
470
 
467
- # Calculate FID score
468
471
  diff = mu1 - mu2
469
472
  covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
470
473
  if np.iscomplexobj(covmean):
@@ -480,7 +483,6 @@ class FID(Criterion):
480
483
  generated_features = FID.get_features(generated_images, self.inception_model)
481
484
 
482
485
  return FID.calculate_fid(real_features, generated_features)
483
-
484
486
 
485
487
  class MutualInformationLoss(torch.nn.Module):
486
488
  def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
@@ -498,7 +500,6 @@ class MutualInformationLoss(torch.nn.Module):
498
500
  target_weight, target_probability = self.parzen_windowing_gaussian(target)
499
501
  return pred_weight, pred_probability, target_weight, target_probability
500
502
 
501
-
502
503
  def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
503
504
  img = torch.clamp(img, 0, 1)
504
505
  img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
@@ -512,4 +513,72 @@ class MutualInformationLoss(torch.nn.Module):
512
513
  pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
513
514
  papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
514
515
  mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
515
- return torch.mean(mi).neg() # average over the batch and channel ndims
516
+ return torch.mean(mi).neg() # average over the batch and channel ndims
517
+
518
+ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
519
+
520
+ def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
521
+ super().__init__()
522
+ if model_name is None:
523
+ return
524
+ self.shape = shape
525
+ self.in_channels = in_channels
526
+ self.losses: dict[torch.nn.Module, list[float]] = {}
527
+ for loss, weights in losses.items():
528
+ module, name = _getModule(loss, "konfai.metric.measure")
529
+ self.losses[config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)] = weights
530
+
531
+ self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
532
+ self.model: torch.nn.Module = torch.jit.load(self.model_path)
533
+ self.dim = len(self.shape)
534
+ self.shape = self.shape if all([s > 0 for s in self.shape]) else None
535
+ self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
536
+
537
+ try:
538
+ dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
539
+ out = self.model(dummy_input)
540
+ if not isinstance(out, (list, tuple)):
541
+ raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
542
+ if len(self.weight) != len(out):
543
+ raise ValueError(f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)}).")
544
+ except Exception as e:
545
+ msg = (
546
+ f"[Model Sanity Check Failed]\n"
547
+ f"Input shape attempted: {dummy_input.shape}\n"
548
+ f"Expected output length: {len(self.weight)}\n"
549
+ f"Error: {type(e).__name__}: {e}"
550
+ )
551
+ raise RuntimeError(msg) from e
552
+ self.model = None
553
+
554
+ def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
555
+ if self.shape is not None and not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
556
+ input = torch.nn.functional.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
557
+ if input.shape[1] != self.in_channels:
558
+ input = input.repeat(tuple([1,3] + [1 for _ in range(self.dim)]))
559
+ return input
560
+
561
+ def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
562
+ loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
563
+ output = self.preprocessing(output)
564
+ targets = [self.preprocessing(target) for target in targets]
565
+ self.model.to(output.device)
566
+ for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
567
+ output_features = zipped_output[0]
568
+ targets_features = zipped_output[1:]
569
+ for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
570
+ for output_feature, target_feature, weight in zip(output_features ,target_features, weights):
571
+ loss += weights*loss_function(output_feature, target_feature)
572
+ return loss
573
+
574
+ def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
575
+ if self.model is None:
576
+ self.model = torch.jit.load(self.model_path)
577
+ loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
578
+ if len(output.shape) == 5 and self.dim == 2:
579
+ for i in range(output.shape[2]):
580
+ loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
581
+ loss/=output.shape[2]
582
+ else:
583
+ loss = self._compute(output, targets)
584
+ return loss.to(output)
@@ -2,10 +2,10 @@ import torch
2
2
  import numpy as np
3
3
  from abc import abstractmethod
4
4
  from konfai.utils.config import config
5
+ from functools import partial
5
6
 
6
7
  class Scheduler():
7
8
 
8
- @config("Scheduler")
9
9
  def __init__(self, start_value: float) -> None:
10
10
  self.baseValue = float(start_value)
11
11
  self.it = 0
@@ -19,7 +19,6 @@ class Scheduler():
19
19
 
20
20
  class Constant(Scheduler):
21
21
 
22
- @config("Constant")
23
22
  def __init__(self, value: float = 1):
24
23
  super().__init__(value)
25
24
 
@@ -28,8 +27,7 @@ class Constant(Scheduler):
28
27
 
29
28
  class CosineAnnealing(Scheduler):
30
29
 
31
- @config("CosineAnnealing")
32
- def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
30
+ def __init__(self, start_value: float = 1, eta_min: float = 0.00001, T_max: int = 100):
33
31
  super().__init__(start_value)
34
32
  self.eta_min = eta_min
35
33
  self.T_max = T_max
@@ -37,13 +35,10 @@ class CosineAnnealing(Scheduler):
37
35
  def get_value(self):
38
36
  return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
39
37
 
40
- class CosineAnnealing(Scheduler):
38
+ class Warmup(torch.optim.lr_scheduler.LambdaLR):
41
39
 
42
- @config("CosineAnnealing")
43
- def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
44
- super().__init__(start_value)
45
- self.eta_min = eta_min
46
- self.T_max = T_max
40
+ def warmup(warmup_steps: int, step: int) -> float:
41
+ return min(1.0, (step+1) / (warmup_steps+1))
47
42
 
48
- def get_value(self):
49
- return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
43
+ def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 10, last_epoch=-1, verbose="deprecated"):
44
+ super().__init__(optimizer, partial(Warmup.warmup, warmup_steps), last_epoch, verbose)
@@ -157,10 +157,9 @@ class Head(network.ModuleArgsDict):
157
157
 
158
158
  class ConvNeXt(network.Network):
159
159
 
160
- @config("ConvNeXt")
161
160
  def __init__( self,
162
161
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
163
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
162
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
164
163
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
165
164
  patch : ModelPatch = ModelPatch(),
166
165
  dim : int = 3,
@@ -98,10 +98,9 @@ class Head(network.ModuleArgsDict):
98
98
 
99
99
  class ResNet(network.Network):
100
100
 
101
- @config("ResNet")
102
101
  def __init__( self,
103
102
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
104
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
103
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
105
104
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
106
105
  patch : ModelPatch = ModelPatch(),
107
106
  dim : int = 3,