konfai 1.1.6__tar.gz → 1.1.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- {konfai-1.1.6 → konfai-1.1.8}/PKG-INFO +1 -1
- {konfai-1.1.6 → konfai-1.1.8}/konfai/data/augmentation.py +0 -1
- {konfai-1.1.6 → konfai-1.1.8}/konfai/data/data_manager.py +25 -39
- {konfai-1.1.6 → konfai-1.1.8}/konfai/data/patching.py +7 -19
- {konfai-1.1.6 → konfai-1.1.8}/konfai/data/transform.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/measure.py +85 -12
- {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/schedulers.py +2 -16
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/classification/convNeXt.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/classification/resnet.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/cStyleGan.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/ddpm.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/diffusionGan.py +10 -23
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/gan.py +3 -6
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/vae.py +0 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/registration/registration.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/representation/representation.py +1 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/segmentation/NestedUNet.py +10 -9
- {konfai-1.1.6 → konfai-1.1.8}/konfai/models/segmentation/UNet.py +1 -1
- {konfai-1.1.6 → konfai-1.1.8}/konfai/network/blocks.py +8 -6
- {konfai-1.1.6 → konfai-1.1.8}/konfai/network/network.py +35 -49
- {konfai-1.1.6 → konfai-1.1.8}/konfai/predictor.py +39 -59
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/utils.py +40 -2
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.6 → konfai-1.1.8}/pyproject.toml +1 -1
- {konfai-1.1.6 → konfai-1.1.8}/LICENSE +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/README.md +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/__init__.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/data/__init__.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/evaluator.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/main.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/network/__init__.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/trainer.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/config.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/registration.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/setup.cfg +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/tests/test_config.py +0 -0
- {konfai-1.1.6 → konfai-1.1.8}/tests/test_dataset.py +0 -0
|
@@ -24,11 +24,11 @@ from konfai.data.augmentation import DataAugmentationsList
|
|
|
24
24
|
class GroupTransform:
|
|
25
25
|
|
|
26
26
|
@config()
|
|
27
|
-
def __init__(self,
|
|
28
|
-
|
|
27
|
+
def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
|
|
28
|
+
patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
|
|
29
29
|
isInput: bool = True) -> None:
|
|
30
|
-
self._pre_transforms =
|
|
31
|
-
self._post_transforms =
|
|
30
|
+
self._pre_transforms = transforms
|
|
31
|
+
self._post_transforms = patch_transforms
|
|
32
32
|
self.pre_transforms : list[Transform] = []
|
|
33
33
|
self.post_transforms : list[Transform] = []
|
|
34
34
|
self.isInput = isInput
|
|
@@ -37,7 +37,7 @@ class GroupTransform:
|
|
|
37
37
|
if self._pre_transforms is not None:
|
|
38
38
|
if isinstance(self._pre_transforms, dict):
|
|
39
39
|
for classpath, transform in self._pre_transforms.items():
|
|
40
|
-
transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.
|
|
40
|
+
transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
|
|
41
41
|
transform.setDatasets(datasets)
|
|
42
42
|
self.pre_transforms.append(transform)
|
|
43
43
|
else:
|
|
@@ -48,7 +48,7 @@ class GroupTransform:
|
|
|
48
48
|
if self._post_transforms is not None:
|
|
49
49
|
if isinstance(self._post_transforms, dict):
|
|
50
50
|
for classpath, transform in self._post_transforms.items():
|
|
51
|
-
transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.
|
|
51
|
+
transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
|
|
52
52
|
transform.setDatasets(datasets)
|
|
53
53
|
self.post_transforms.append(transform)
|
|
54
54
|
else:
|
|
@@ -65,9 +65,8 @@ class GroupTransform:
|
|
|
65
65
|
class GroupTransformMetric(GroupTransform):
|
|
66
66
|
|
|
67
67
|
@config()
|
|
68
|
-
def __init__(self,
|
|
69
|
-
|
|
70
|
-
super().__init__(pre_transforms, post_transforms)
|
|
68
|
+
def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
|
|
69
|
+
super().__init__(transforms, None)
|
|
71
70
|
|
|
72
71
|
class Group(dict[str, GroupTransform]):
|
|
73
72
|
|
|
@@ -196,10 +195,9 @@ class DatasetIter(data.Dataset):
|
|
|
196
195
|
|
|
197
196
|
class Subset():
|
|
198
197
|
|
|
199
|
-
def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True
|
|
198
|
+
def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
|
|
200
199
|
self.subset = subset
|
|
201
200
|
self.shuffle = shuffle
|
|
202
|
-
self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None
|
|
203
201
|
|
|
204
202
|
def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
|
|
205
203
|
inter_name = set(names[0])
|
|
@@ -207,16 +205,7 @@ class Subset():
|
|
|
207
205
|
inter_name = inter_name.intersection(set(n))
|
|
208
206
|
names = sorted(list(inter_name))
|
|
209
207
|
|
|
210
|
-
|
|
211
|
-
if self.filter is not None:
|
|
212
|
-
for name in names:
|
|
213
|
-
for info in infos:
|
|
214
|
-
if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
|
|
215
|
-
names_filtred.append(name)
|
|
216
|
-
continue
|
|
217
|
-
else:
|
|
218
|
-
names_filtred = names
|
|
219
|
-
size = len(names_filtred)
|
|
208
|
+
size = len(names)
|
|
220
209
|
index = []
|
|
221
210
|
if self.subset is None:
|
|
222
211
|
index = list(range(0, size))
|
|
@@ -230,7 +219,7 @@ class Subset():
|
|
|
230
219
|
for name in f:
|
|
231
220
|
train_names.append(name.strip())
|
|
232
221
|
index = []
|
|
233
|
-
for i, name in enumerate(
|
|
222
|
+
for i, name in enumerate(names):
|
|
234
223
|
if name in train_names:
|
|
235
224
|
index.append(i)
|
|
236
225
|
elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
|
|
@@ -239,7 +228,7 @@ class Subset():
|
|
|
239
228
|
for name in f:
|
|
240
229
|
exclude_names.append(name.strip())
|
|
241
230
|
index = []
|
|
242
|
-
for i, name in enumerate(
|
|
231
|
+
for i, name in enumerate(names):
|
|
243
232
|
if name not in exclude_names:
|
|
244
233
|
index.append(i)
|
|
245
234
|
|
|
@@ -252,26 +241,27 @@ class Subset():
|
|
|
252
241
|
index = self.subset
|
|
253
242
|
if isinstance(self.subset[0], str):
|
|
254
243
|
index = []
|
|
255
|
-
for i, name in enumerate(
|
|
244
|
+
for i, name in enumerate(names):
|
|
256
245
|
if name in self.subset:
|
|
257
246
|
index.append(i)
|
|
258
247
|
if self.shuffle:
|
|
259
248
|
index = random.sample(index, len(index))
|
|
260
|
-
return set([
|
|
249
|
+
return set([names[i] for i in index])
|
|
261
250
|
|
|
262
251
|
def __str__(self):
|
|
263
|
-
return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
|
|
252
|
+
return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
|
|
253
|
+
|
|
264
254
|
class TrainSubset(Subset):
|
|
265
255
|
|
|
266
256
|
@config()
|
|
267
|
-
def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True
|
|
268
|
-
super().__init__(subset, shuffle
|
|
257
|
+
def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
|
|
258
|
+
super().__init__(subset, shuffle)
|
|
269
259
|
|
|
270
260
|
class PredictionSubset(Subset):
|
|
271
261
|
|
|
272
262
|
@config()
|
|
273
|
-
def __init__(self, subset: Union[str, list[int], list[str], None] = None
|
|
274
|
-
super().__init__(subset, False
|
|
263
|
+
def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
|
|
264
|
+
super().__init__(subset, False)
|
|
275
265
|
|
|
276
266
|
class Data(ABC):
|
|
277
267
|
|
|
@@ -280,7 +270,6 @@ class Data(ABC):
|
|
|
280
270
|
patch : Union[DatasetPatch, None],
|
|
281
271
|
use_cache : bool,
|
|
282
272
|
subset : Subset,
|
|
283
|
-
num_workers : int,
|
|
284
273
|
batch_size : int,
|
|
285
274
|
validation: Union[float, str, list[int], list[str], None] = None,
|
|
286
275
|
inlineAugmentations: bool = False,
|
|
@@ -293,7 +282,7 @@ class Data(ABC):
|
|
|
293
282
|
self.dataAugmentationsList = dataAugmentationsList
|
|
294
283
|
self.batch_size = batch_size
|
|
295
284
|
self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
|
|
296
|
-
self.dataLoader_args = dict(num_workers=
|
|
285
|
+
self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]) if use_cache else 0, pin_memory=True)
|
|
297
286
|
self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
|
|
298
287
|
self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
|
|
299
288
|
self.datasets: dict[str, Dataset] = {}
|
|
@@ -547,10 +536,9 @@ class DataTrain(Data):
|
|
|
547
536
|
patch : Union[DatasetPatch, None] = DatasetPatch(),
|
|
548
537
|
use_cache : bool = True,
|
|
549
538
|
subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
|
|
550
|
-
num_workers : int = 4,
|
|
551
539
|
batch_size : int = 1,
|
|
552
540
|
validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
|
|
553
|
-
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset,
|
|
541
|
+
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
|
|
554
542
|
|
|
555
543
|
class DataPrediction(Data):
|
|
556
544
|
|
|
@@ -560,10 +548,9 @@ class DataPrediction(Data):
|
|
|
560
548
|
augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
|
|
561
549
|
patch : Union[DatasetPatch, None] = DatasetPatch(),
|
|
562
550
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
563
|
-
num_workers : int = 4,
|
|
564
551
|
batch_size : int = 1) -> None:
|
|
565
552
|
|
|
566
|
-
super().__init__(dataset_filenames, groups_src, patch, False, subset,
|
|
553
|
+
super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
|
|
567
554
|
|
|
568
555
|
class DataMetric(Data):
|
|
569
556
|
|
|
@@ -571,7 +558,6 @@ class DataMetric(Data):
|
|
|
571
558
|
def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
|
|
572
559
|
groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
|
|
573
560
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
574
|
-
validation: Union[str, None] = None
|
|
575
|
-
num_workers : int = 4) -> None:
|
|
561
|
+
validation: Union[str, None] = None) -> None:
|
|
576
562
|
|
|
577
|
-
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset,
|
|
563
|
+
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
|
|
@@ -81,7 +81,6 @@ class PathCombine(ABC):
|
|
|
81
81
|
|
|
82
82
|
class Mean(PathCombine):
|
|
83
83
|
|
|
84
|
-
@config("Mean")
|
|
85
84
|
def __init__(self) -> None:
|
|
86
85
|
super().__init__()
|
|
87
86
|
|
|
@@ -90,7 +89,6 @@ class Mean(PathCombine):
|
|
|
90
89
|
|
|
91
90
|
class Cosinus(PathCombine):
|
|
92
91
|
|
|
93
|
-
@config("Cosinus")
|
|
94
92
|
def __init__(self) -> None:
|
|
95
93
|
super().__init__()
|
|
96
94
|
|
|
@@ -144,7 +142,7 @@ class Accumulator():
|
|
|
144
142
|
|
|
145
143
|
class Patch(ABC):
|
|
146
144
|
|
|
147
|
-
def __init__(self, patch_size: list[int], overlap: Union[int, None],
|
|
145
|
+
def __init__(self, patch_size: list[int], overlap: Union[int, None], padValue: float = 0, extend_slice: int = 0) -> None:
|
|
148
146
|
self.patch_size = patch_size
|
|
149
147
|
self.overlap = overlap
|
|
150
148
|
if isinstance(self.overlap, int):
|
|
@@ -152,15 +150,8 @@ class Patch(ABC):
|
|
|
152
150
|
self.overlap = None
|
|
153
151
|
self._patch_slices : dict[int, list[tuple[slice]]] = {}
|
|
154
152
|
self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
|
|
155
|
-
self.path_mask = path_mask
|
|
156
|
-
self.mask = None
|
|
157
153
|
self.padValue = padValue
|
|
158
154
|
self.extend_slice = extend_slice
|
|
159
|
-
if self.path_mask is not None:
|
|
160
|
-
if os.path.exists(self.path_mask):
|
|
161
|
-
self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path_mask)))
|
|
162
|
-
else:
|
|
163
|
-
raise NameError('Mask file not found')
|
|
164
155
|
|
|
165
156
|
def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
|
|
166
157
|
self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
|
|
@@ -199,9 +190,6 @@ class Patch(ABC):
|
|
|
199
190
|
padding.append(p)
|
|
200
191
|
|
|
201
192
|
data_sliced = F.pad(data_sliced, tuple(padding), "constant", 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
|
|
202
|
-
if self.mask is not None:
|
|
203
|
-
outside = torch.ones_like(data_sliced)*(0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
|
|
204
|
-
data_sliced = torch.where(self.mask == 0, outside, data_sliced)
|
|
205
193
|
|
|
206
194
|
for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
|
|
207
195
|
data_sliced = torch.squeeze(data_sliced, dim = len(data_sliced.shape)-d-1)
|
|
@@ -213,14 +201,14 @@ class Patch(ABC):
|
|
|
213
201
|
class DatasetPatch(Patch):
|
|
214
202
|
|
|
215
203
|
@config("Patch")
|
|
216
|
-
def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None,
|
|
217
|
-
super().__init__(patch_size, overlap,
|
|
204
|
+
def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
|
|
205
|
+
super().__init__(patch_size, overlap, padValue, extend_slice)
|
|
218
206
|
|
|
219
207
|
class ModelPatch(Patch):
|
|
220
208
|
|
|
221
209
|
@config("Patch")
|
|
222
|
-
def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None,
|
|
223
|
-
super().__init__(patch_size, overlap,
|
|
210
|
+
def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
|
|
211
|
+
super().__init__(patch_size, overlap, padValue, extend_slice)
|
|
224
212
|
self.patchCombine = patchCombine
|
|
225
213
|
|
|
226
214
|
def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
|
|
@@ -247,7 +235,7 @@ class DatasetManager():
|
|
|
247
235
|
for transformFunction in pre_transforms:
|
|
248
236
|
_shape = transformFunction.transformShape(_shape, cache_attribute)
|
|
249
237
|
|
|
250
|
-
self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap,
|
|
238
|
+
self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
|
|
251
239
|
self.patch.load(_shape, 0)
|
|
252
240
|
self.shape = _shape
|
|
253
241
|
self.dataAugmentationsList = dataAugmentationsList
|
|
@@ -314,7 +302,7 @@ class DatasetManager():
|
|
|
314
302
|
dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
|
|
315
303
|
self.data : list[torch.Tensor] = list()
|
|
316
304
|
self.data.append(data)
|
|
317
|
-
|
|
305
|
+
|
|
318
306
|
for i in range(len(self.cache_attributes)-1):
|
|
319
307
|
self.cache_attributes[i+1].update(self.cache_attributes[0])
|
|
320
308
|
self.loaded = True
|
|
@@ -259,7 +259,6 @@ class ResampleToShape(Resample):
|
|
|
259
259
|
for i, s in enumerate(self.shape):
|
|
260
260
|
if s == 0:
|
|
261
261
|
new_shape[i] = shape[i]
|
|
262
|
-
print(new_shape)
|
|
263
262
|
return new_shape
|
|
264
263
|
|
|
265
264
|
def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -411,7 +410,7 @@ class Gradient(Transform):
|
|
|
411
410
|
def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
412
411
|
return input
|
|
413
412
|
|
|
414
|
-
class
|
|
413
|
+
class Argmax(Transform):
|
|
415
414
|
|
|
416
415
|
def __init__(self, dim: int = 0) -> None:
|
|
417
416
|
self.dim = dim
|
|
@@ -9,7 +9,6 @@ from scipy import linalg
|
|
|
9
9
|
|
|
10
10
|
import torch.nn.functional as F
|
|
11
11
|
import os
|
|
12
|
-
import tqdm
|
|
13
12
|
|
|
14
13
|
from typing import Callable, Union
|
|
15
14
|
from functools import partial
|
|
@@ -19,10 +18,11 @@ import copy
|
|
|
19
18
|
from abc import abstractmethod
|
|
20
19
|
|
|
21
20
|
from konfai.utils.config import config
|
|
22
|
-
from konfai.utils.utils import _getModule
|
|
21
|
+
from konfai.utils.utils import _getModule, download_url
|
|
23
22
|
from konfai.data.patching import ModelPatch
|
|
24
23
|
from konfai.network.blocks import LatentDistribution
|
|
25
24
|
from konfai.network.network import ModelLoader, Network
|
|
25
|
+
from tqdm import tqdm
|
|
26
26
|
|
|
27
27
|
modelsRegister = {}
|
|
28
28
|
|
|
@@ -122,7 +122,7 @@ class LPIPS(MaskedLoss):
|
|
|
122
122
|
|
|
123
123
|
patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
|
|
124
124
|
loss = 0
|
|
125
|
-
with tqdm
|
|
125
|
+
with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
|
|
126
126
|
for i, patch_input in batch_iter:
|
|
127
127
|
real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
|
|
128
128
|
loss += loss_fn_alex(real, fake).item()
|
|
@@ -146,9 +146,11 @@ class Dice(Criterion):
|
|
|
146
146
|
target = self.flatten(target)
|
|
147
147
|
return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
|
|
148
148
|
|
|
149
|
-
|
|
150
149
|
def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
|
|
151
|
-
target =
|
|
150
|
+
target = F.interpolate(
|
|
151
|
+
targets[0],
|
|
152
|
+
output.shape[2:],
|
|
153
|
+
mode="nearest")
|
|
152
154
|
if output.shape[1] == 1:
|
|
153
155
|
output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
|
|
154
156
|
target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
|
|
@@ -242,7 +244,7 @@ class PerceptualLoss(Criterion):
|
|
|
242
244
|
|
|
243
245
|
class Module():
|
|
244
246
|
|
|
245
|
-
@config(
|
|
247
|
+
@config()
|
|
246
248
|
def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
|
|
247
249
|
self.losses = losses
|
|
248
250
|
self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
|
|
@@ -413,16 +415,19 @@ class FocalLoss(Criterion):
|
|
|
413
415
|
self.gamma = gamma
|
|
414
416
|
self.reduction = reduction
|
|
415
417
|
|
|
416
|
-
def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
|
|
417
|
-
target =
|
|
418
|
+
def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
|
|
419
|
+
target = F.interpolate(
|
|
420
|
+
targets[0],
|
|
421
|
+
output.shape[2:],
|
|
422
|
+
mode="nearest").long()
|
|
418
423
|
|
|
419
424
|
logpt = F.log_softmax(output, dim=1)
|
|
420
425
|
pt = torch.exp(logpt)
|
|
421
426
|
|
|
422
|
-
logpt = logpt.gather(1, target
|
|
423
|
-
pt = pt.gather(1, target
|
|
427
|
+
logpt = logpt.gather(1, target)
|
|
428
|
+
pt = pt.gather(1, target)
|
|
424
429
|
|
|
425
|
-
at = self.alpha[target].unsqueeze(1)
|
|
430
|
+
at = self.alpha.to(target.device)[target].unsqueeze(1)
|
|
426
431
|
loss = -at * ((1 - pt) ** self.gamma) * logpt
|
|
427
432
|
|
|
428
433
|
if self.reduction == "mean":
|
|
@@ -508,4 +513,72 @@ class MutualInformationLoss(torch.nn.Module):
|
|
|
508
513
|
pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
|
|
509
514
|
papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
|
|
510
515
|
mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
|
|
511
|
-
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
516
|
+
return torch.mean(mi).neg() # average over the batch and channel ndims
|
|
517
|
+
|
|
518
|
+
class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
    """Feature-Oriented Comparison for Unpaired Synthesis.

    Compares images in the feature space of a TorchScript model whose file is
    resolved with ``download_url`` against the
    ``VBoussot/impact-torchscript-models`` Hugging Face base URL. The model
    must return a list/tuple of feature tensors (one per extracted output).

    Losses are paired positionally with the targets given to :meth:`forward`:
    target 0 is compared using the first entry of ``losses``, target 1 using
    the second, and so on. Within each pair, every model output contributes
    ``weight * loss_function(output_feature, target_feature)``.

    Args:
        model_name: Model file name to fetch. ``None`` returns early and
            leaves the instance without a model.
        shape: Spatial size fed to the model; its length is the model
            dimensionality. If any entry is <= 0, inputs are not resized.
        in_channels: Channel count expected by the model. Inputs whose channel
            count divides ``in_channels`` are repeated along axis 1 to match.
        losses: ``loss class path -> weights``. ``weights`` is either one
            weight per model output or a single number applied to every
            output.

    Raises:
        RuntimeError: if the sanity forward pass on a zero tensor fails, the
            model does not return a list/tuple, or a weight list length does
            not match the number of model outputs.
    """

    # Interpolation mode per spatial dimensionality (all accept align_corners).
    _INTERPOLATION_MODES = {1: "linear", 2: "bilinear", 3: "trilinear"}

    def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
        super().__init__()
        if model_name is None:
            return
        self.shape = shape
        self.in_channels = in_channels
        self.losses: dict[torch.nn.Module, list[float]] = {}
        for loss, weights in losses.items():
            module, name = _getModule(loss, "konfai.metric.measure")
            loss_function = config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)
            self.losses[loss_function] = weights

        self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
        self.model: torch.nn.Module = torch.jit.load(self.model_path)
        self.dim = len(self.shape)
        # A non-positive entry means "keep the input's native spatial size".
        self.shape = self.shape if all([s > 0 for s in self.shape]) else None
        self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}

        dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
        try:
            with torch.no_grad():
                out = self.model(dummy_input)
            if not isinstance(out, (list, tuple)):
                raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
            # Validate (and broadcast scalar) weights against the number of model outputs.
            self.losses = {loss_function: FOCUS._expand_weights(weights, len(out)) for loss_function, weights in self.losses.items()}
        except Exception as e:
            msg = (
                f"[Model Sanity Check Failed]\n"
                f"Input shape attempted: {dummy_input.shape}\n"
                f"Configured loss weights: {list(self.losses.values())}\n"
                f"Error: {type(e).__name__}: {e}"
            )
            raise RuntimeError(msg) from e
        # Release the model after the check; forward() reloads it lazily from model_path.
        self.model = None

    @staticmethod
    def _expand_weights(weights, nb_outputs: int) -> list[float]:
        """Return one weight per model output, broadcasting a scalar weight."""
        if isinstance(weights, (int, float)):
            return [float(weights)] * nb_outputs
        weights = list(weights)
        if len(weights) != nb_outputs:
            raise ValueError(f"Mismatch between number of weights ({len(weights)}) and model outputs ({nb_outputs}).")
        return weights

    def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
        """Resize ``input`` to ``self.shape`` (if set) and match ``self.in_channels``.

        Assumes ``input`` is laid out as (batch, channels, *spatial) with
        ``self.dim`` spatial axes — TODO confirm for all callers.
        """
        if self.shape and tuple(input.shape[2:]) != tuple(self.shape):
            # Cast before interpolating: (bi/tri)linear modes reject integer tensors.
            input = torch.nn.functional.interpolate(input.type(torch.float32), mode=FOCUS._INTERPOLATION_MODES[self.dim], size=tuple(self.shape), align_corners=False)
        channels = input.shape[1]
        if channels != self.in_channels:
            if self.in_channels % channels != 0:
                raise ValueError(f"Cannot map {channels} input channel(s) to the {self.in_channels} channel(s) expected by the model.")
            repeats = [1] * input.dim()
            repeats[1] = self.in_channels // channels
            input = input.repeat(*repeats)
        return input

    def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
        """Weighted sum of per-output feature losses, pairing target i with loss i."""
        # Out-of-place accumulation: in-place ops on a leaf requiring grad are illegal.
        loss = torch.zeros(1, dtype=torch.float32, device=output.device, requires_grad=True)
        output = self.preprocessing(output)
        targets = [self.preprocessing(target) for target in targets]
        self.model.to(output.device)
        output_features = self.model(output)
        for target, (loss_function, weights) in zip(targets, self.losses.items()):
            target_features = self.model(target)
            for output_feature, target_feature, weight in zip(output_features, target_features, weights):
                loss = loss + weight * loss_function(output_feature, target_feature)
        return loss

    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
        """Compute the FOCUS loss between ``output`` and ``targets``.

        A 5-D input with a 2-D model is processed slice by slice along axis 2,
        and the per-slice losses are averaged.
        """
        if self.model is None:
            self.model = torch.jit.load(self.model_path)
        if len(output.shape) == 5 and self.dim == 2:
            nb_slices = output.shape[2]
            loss = torch.zeros(1, dtype=torch.float32, device=output.device, requires_grad=True)
            for i in range(nb_slices):
                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
            if nb_slices > 0:
                loss = loss / nb_slices
        else:
            loss = self._compute(output, targets)
        return loss.to(output)
|
|
@@ -6,7 +6,6 @@ from functools import partial
|
|
|
6
6
|
|
|
7
7
|
class Scheduler():
|
|
8
8
|
|
|
9
|
-
@config("Scheduler")
|
|
10
9
|
def __init__(self, start_value: float) -> None:
|
|
11
10
|
self.baseValue = float(start_value)
|
|
12
11
|
self.it = 0
|
|
@@ -20,7 +19,6 @@ class Scheduler():
|
|
|
20
19
|
|
|
21
20
|
class Constant(Scheduler):
|
|
22
21
|
|
|
23
|
-
@config("Constant")
|
|
24
22
|
def __init__(self, value: float = 1):
|
|
25
23
|
super().__init__(value)
|
|
26
24
|
|
|
@@ -29,19 +27,7 @@ class Constant(Scheduler):
|
|
|
29
27
|
|
|
30
28
|
class CosineAnnealing(Scheduler):
|
|
31
29
|
|
|
32
|
-
|
|
33
|
-
def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
|
|
34
|
-
super().__init__(start_value)
|
|
35
|
-
self.eta_min = eta_min
|
|
36
|
-
self.T_max = T_max
|
|
37
|
-
|
|
38
|
-
def get_value(self):
|
|
39
|
-
return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
|
|
40
|
-
|
|
41
|
-
class CosineAnnealing(Scheduler):
|
|
42
|
-
|
|
43
|
-
@config("CosineAnnealing")
|
|
44
|
-
def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
|
|
30
|
+
def __init__(self, start_value: float = 1, eta_min: float = 0.00001, T_max: int = 100):
|
|
45
31
|
super().__init__(start_value)
|
|
46
32
|
self.eta_min = eta_min
|
|
47
33
|
self.T_max = T_max
|
|
@@ -52,7 +38,7 @@ class CosineAnnealing(Scheduler):
|
|
|
52
38
|
class Warmup(torch.optim.lr_scheduler.LambdaLR):
|
|
53
39
|
|
|
54
40
|
def warmup(warmup_steps: int, step: int) -> float:
|
|
55
|
-
return min(1.0, step / warmup_steps)
|
|
41
|
+
return min(1.0, (step+1) / (warmup_steps+1))
|
|
56
42
|
|
|
57
43
|
def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 10, last_epoch=-1, verbose="deprecated"):
|
|
58
44
|
super().__init__(optimizer, partial(Warmup.warmup, warmup_steps), last_epoch, verbose)
|
|
@@ -157,10 +157,9 @@ class Head(network.ModuleArgsDict):
|
|
|
157
157
|
|
|
158
158
|
class ConvNeXt(network.Network):
|
|
159
159
|
|
|
160
|
-
@config("ConvNeXt")
|
|
161
160
|
def __init__( self,
|
|
162
161
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
163
|
-
schedulers
|
|
162
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
164
163
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
165
164
|
patch : ModelPatch = ModelPatch(),
|
|
166
165
|
dim : int = 3,
|
|
@@ -98,10 +98,9 @@ class Head(network.ModuleArgsDict):
|
|
|
98
98
|
|
|
99
99
|
class ResNet(network.Network):
|
|
100
100
|
|
|
101
|
-
@config("ResNet")
|
|
102
101
|
def __init__( self,
|
|
103
102
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
104
|
-
schedulers
|
|
103
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
105
104
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
106
105
|
patch : ModelPatch = ModelPatch(),
|
|
107
106
|
dim : int = 3,
|
|
@@ -118,10 +118,9 @@ class Generator(network.Network):
|
|
|
118
118
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
|
|
119
119
|
self.add_module("Tanh", torch.nn.Tanh())
|
|
120
120
|
|
|
121
|
-
@config("Generator")
|
|
122
121
|
def __init__(self,
|
|
123
122
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
124
|
-
schedulers
|
|
123
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
125
124
|
patch : ModelPatch = ModelPatch(),
|
|
126
125
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
127
126
|
channels: list[int]=[1, 64, 128, 256, 512, 1024],
|
|
@@ -175,10 +175,9 @@ class DDPM(network.Network):
|
|
|
175
175
|
def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
|
|
176
176
|
return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
|
|
177
177
|
|
|
178
|
-
@config("DDPM")
|
|
179
178
|
def __init__( self,
|
|
180
179
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
181
|
-
schedulers
|
|
180
|
+
schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
|
|
182
181
|
outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
|
|
183
182
|
patch : Union[ModelPatch, None] = None,
|
|
184
183
|
train_noise_step: int = 1000,
|