konfai 1.1.5.tar.gz → 1.1.7.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {konfai-1.1.5 → konfai-1.1.7}/PKG-INFO +1 -1
- {konfai-1.1.5 → konfai-1.1.7}/konfai/data/augmentation.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/data/data_manager.py +26 -40
- {konfai-1.1.5 → konfai-1.1.7}/konfai/data/patching.py +7 -19
- {konfai-1.1.5 → konfai-1.1.7}/konfai/data/transform.py +23 -21
- {konfai-1.1.5 → konfai-1.1.7}/konfai/evaluator.py +1 -1
- {konfai-1.1.5 → konfai-1.1.7}/konfai/main.py +0 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/metric/measure.py +86 -17
- {konfai-1.1.5 → konfai-1.1.7}/konfai/metric/schedulers.py +7 -12
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/classification/convNeXt.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/classification/resnet.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/generation/cStyleGan.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/generation/ddpm.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/generation/diffusionGan.py +10 -23
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/generation/gan.py +3 -6
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/generation/vae.py +0 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/registration/registration.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/representation/representation.py +1 -2
- konfai-1.1.7/konfai/models/segmentation/NestedUNet.py +114 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/models/segmentation/UNet.py +1 -2
- {konfai-1.1.5 → konfai-1.1.7}/konfai/network/blocks.py +8 -6
- {konfai-1.1.5 → konfai-1.1.7}/konfai/network/network.py +54 -47
- {konfai-1.1.5 → konfai-1.1.7}/konfai/predictor.py +55 -76
- {konfai-1.1.5 → konfai-1.1.7}/konfai/trainer.py +5 -4
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/utils.py +41 -3
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.5 → konfai-1.1.7}/pyproject.toml +1 -1
- konfai-1.1.5/konfai/models/segmentation/NestedUNet.py +0 -115
- {konfai-1.1.5 → konfai-1.1.7}/LICENSE +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/README.md +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/__init__.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/data/__init__.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/network/__init__.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/config.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai/utils/registration.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/setup.cfg +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/tests/test_config.py +0 -0
- {konfai-1.1.5 → konfai-1.1.7}/tests/test_dataset.py +0 -0
@@ -57,7 +57,7 @@ class DataAugmentationsList():
 
     def load(self, key: str, datasets: list[Dataset]):
         for augmentation, prob in self.dataAugmentationsLoader.items():
-            module, name = _getModule(augmentation, "data.augmentation")
+            module, name = _getModule(augmentation, "konfai.data.augmentation")
             dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
             dataAugmentation.load(prob.prob)
             dataAugmentation.setDatasets(datasets)
@@ -264,7 +264,6 @@ class LumaFlip(ColorTransform):
 
 class HUE(ColorTransform):
 
-    @config("HUE")
     def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.hue_max = hue_max
@@ -24,11 +24,11 @@ from konfai.data.augmentation import DataAugmentationsList
 class GroupTransform:
 
     @config()
-    def __init__(self,
-
+    def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
+                patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
                 isInput: bool = True) -> None:
-        self._pre_transforms =
-        self._post_transforms =
+        self._pre_transforms = transforms
+        self._post_transforms = patch_transforms
         self.pre_transforms : list[Transform] = []
         self.post_transforms : list[Transform] = []
         self.isInput = isInput
@@ -37,7 +37,7 @@ class GroupTransform:
         if self._pre_transforms is not None:
             if isinstance(self._pre_transforms, dict):
                 for classpath, transform in self._pre_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.pre_transforms.append(transform)
             else:
@@ -48,7 +48,7 @@ class GroupTransform:
         if self._post_transforms is not None:
             if isinstance(self._post_transforms, dict):
                 for classpath, transform in self._post_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.post_transforms.append(transform)
             else:
@@ -65,9 +65,8 @@ class GroupTransform:
 class GroupTransformMetric(GroupTransform):
 
     @config()
-    def __init__(self,
-
-        super().__init__(pre_transforms, post_transforms)
+    def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
+        super().__init__(transforms, None)
 
 class Group(dict[str, GroupTransform]):
 
@@ -184,7 +183,7 @@ class DatasetIter(data.Dataset):
         data = {}
         x, a, p = self.map[index]
         if x not in self._index_cache:
-            if
+            if len(self._index_cache) >= self.buffer_size and not self.use_cache:
                 self._unloadData(self._index_cache[0])
             self._loadData(x)
 
@@ -196,10 +195,9 @@ class DatasetIter(data.Dataset):
 
 class Subset():
 
-    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True
+    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
         self.subset = subset
         self.shuffle = shuffle
-        self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None
 
     def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
         inter_name = set(names[0])
@@ -207,16 +205,7 @@ class Subset():
             inter_name = inter_name.intersection(set(n))
         names = sorted(list(inter_name))
 
-
-        if self.filter is not None:
-            for name in names:
-                for info in infos:
-                    if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
-                        names_filtred.append(name)
-                        continue
-        else:
-            names_filtred = names
-        size = len(names_filtred)
+        size = len(names)
         index = []
         if self.subset is None:
             index = list(range(0, size))
@@ -230,7 +219,7 @@ class Subset():
                 for name in f:
                     train_names.append(name.strip())
             index = []
-            for i, name in enumerate(
+            for i, name in enumerate(names):
                 if name in train_names:
                     index.append(i)
         elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
@@ -239,7 +228,7 @@ class Subset():
                 for name in f:
                     exclude_names.append(name.strip())
             index = []
-            for i, name in enumerate(
+            for i, name in enumerate(names):
                 if name not in exclude_names:
                     index.append(i)
 
@@ -252,26 +241,27 @@ class Subset():
             index = self.subset
             if isinstance(self.subset[0], str):
                 index = []
-                for i, name in enumerate(
+                for i, name in enumerate(names):
                     if name in self.subset:
                         index.append(i)
         if self.shuffle:
             index = random.sample(index, len(index))
-        return set([
+        return set([names[i] for i in index])
 
     def __str__(self):
-        return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
+        return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
+
 class TrainSubset(Subset):
 
     @config()
-    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True
-        super().__init__(subset, shuffle
+    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
+        super().__init__(subset, shuffle)
 
 class PredictionSubset(Subset):
 
     @config()
-    def __init__(self, subset: Union[str, list[int], list[str], None] = None
-        super().__init__(subset, False
+    def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
+        super().__init__(subset, False)
 
 class Data(ABC):
 
@@ -280,7 +270,6 @@ class Data(ABC):
                 patch : Union[DatasetPatch, None],
                 use_cache : bool,
                 subset : Subset,
-                num_workers : int,
                 batch_size : int,
                 validation: Union[float, str, list[int], list[str], None] = None,
                 inlineAugmentations: bool = False,
@@ -293,7 +282,7 @@ class Data(ABC):
         self.dataAugmentationsList = dataAugmentationsList
         self.batch_size = batch_size
         self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
-        self.dataLoader_args = dict(num_workers=
+        self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]), pin_memory=True)
         self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
         self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
         self.datasets: dict[str, Dataset] = {}
@@ -547,10 +536,9 @@ class DataTrain(Data):
                 patch : Union[DatasetPatch, None] = DatasetPatch(),
                 use_cache : bool = True,
                 subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
-                num_workers : int = 4,
                 batch_size : int = 1,
                 validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
-        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset,
+        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
 
 class DataPrediction(Data):
 
@@ -560,10 +548,9 @@ class DataPrediction(Data):
                 augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                 patch : Union[DatasetPatch, None] = DatasetPatch(),
                 subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
-                num_workers : int = 4,
                 batch_size : int = 1) -> None:
 
-        super().__init__(dataset_filenames, groups_src, patch, False, subset,
+        super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
 
 class DataMetric(Data):
 
@@ -571,7 +558,6 @@ class DataMetric(Data):
     def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                 groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
                 subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
-                validation: Union[str, None] = None
-                num_workers : int = 4) -> None:
+                validation: Union[str, None] = None) -> None:
 
-        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset,
+        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
@@ -81,7 +81,6 @@ class PathCombine(ABC):
 
 class Mean(PathCombine):
 
-    @config("Mean")
     def __init__(self) -> None:
         super().__init__()
 
@@ -90,7 +89,6 @@ class Mean(PathCombine):
 
 class Cosinus(PathCombine):
 
-    @config("Cosinus")
     def __init__(self) -> None:
         super().__init__()
 
@@ -144,7 +142,7 @@ class Accumulator():
 
 class Patch(ABC):
 
-    def __init__(self, patch_size: list[int], overlap: Union[int, None],
+    def __init__(self, patch_size: list[int], overlap: Union[int, None], padValue: float = 0, extend_slice: int = 0) -> None:
         self.patch_size = patch_size
         self.overlap = overlap
         if isinstance(self.overlap, int):
@@ -152,15 +150,8 @@ class Patch(ABC):
             self.overlap = None
         self._patch_slices : dict[int, list[tuple[slice]]] = {}
         self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
-        self.path_mask = path_mask
-        self.mask = None
         self.padValue = padValue
         self.extend_slice = extend_slice
-        if self.path_mask is not None:
-            if os.path.exists(self.path_mask):
-                self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path_mask)))
-            else:
-                raise NameError('Mask file not found')
 
     def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
         self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
@@ -199,9 +190,6 @@ class Patch(ABC):
                 padding.append(p)
 
         data_sliced = F.pad(data_sliced, tuple(padding), "constant", 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
-        if self.mask is not None:
-            outside = torch.ones_like(data_sliced)*(0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
-            data_sliced = torch.where(self.mask == 0, outside, data_sliced)
 
         for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
             data_sliced = torch.squeeze(data_sliced, dim = len(data_sliced.shape)-d-1)
@@ -213,14 +201,14 @@ class Patch(ABC):
 class DatasetPatch(Patch):
 
     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None,
-        super().__init__(patch_size, overlap,
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
+        super().__init__(patch_size, overlap, padValue, extend_slice)
 
 class ModelPatch(Patch):
 
     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None,
-        super().__init__(patch_size, overlap,
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
+        super().__init__(patch_size, overlap, padValue, extend_slice)
         self.patchCombine = patchCombine
 
     def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
@@ -247,7 +235,7 @@ class DatasetManager():
         for transformFunction in pre_transforms:
             _shape = transformFunction.transformShape(_shape, cache_attribute)
 
-        self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap,
+        self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
         self.patch.load(_shape, 0)
         self.shape = _shape
         self.dataAugmentationsList = dataAugmentationsList
@@ -314,7 +302,7 @@ class DatasetManager():
             dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
         self.data : list[torch.Tensor] = list()
         self.data.append(data)
-
+
         for i in range(len(self.cache_attributes)-1):
             self.cache_attributes[i+1].update(self.cache_attributes[0])
         self.loaded = True
@@ -36,7 +36,7 @@ class TransformLoader:
         pass
 
     def getTransform(self, classpath : str, DL_args : str) -> Transform:
-        module, name = _getModule(classpath, "data.transform")
+        module, name = _getModule(classpath, "konfai.data.transform")
         return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config = None)
 
 class Clip(Transform):
@@ -213,7 +213,7 @@ class Resample(Transform, ABC):
 
 class ResampleToResolution(Resample):
 
-    def __init__(self, spacing : list[
+    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
         self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -244,34 +244,36 @@ class ResampleToResolution(Resample):
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(input, size)
 
-class
+class ResampleToShape(Resample):
 
-    def __init__(self,
-        self.
+    def __init__(self, shape : list[float] = [100,256,256]) -> None:
+        self.shape = torch.tensor([0 if s < 0 else s for s in shape])
 
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+        print(shape)
         if "Spacing" not in cache_attribute:
             TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
                 "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
-        if len(shape) != len(self.
+        if len(shape) != len(self.shape):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
-
-        for i, s in enumerate(self.
-            if s ==
-
-
+        new_shape = self.shape
+        for i, s in enumerate(self.shape):
+            if s == 0:
+                new_shape[i] = shape[i]
+        print(new_shape)
+        return new_shape
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-
-
-        for i, s in enumerate(self.
-            if s
-
+        shape = self.shape
+        image_shape = torch.tensor([int(x) for x in torch.tensor(input.shape[1:])])
+        for i, s in enumerate(self.shape):
+            if s == 0:
+                shape[i] = image_shape[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(
-        cache_attribute["Size"] =
-        cache_attribute["Size"] =
-        return self._resample(input,
+            cache_attribute["Spacing"] = torch.flip(image_shape/shape*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+        cache_attribute["Size"] = image_shape
+        cache_attribute["Size"] = shape
+        return self._resample(input, shape)
 
 class ResampleTransform(Transform):
 
@@ -410,7 +412,7 @@ class Gradient(Transform):
     def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return input
 
-class
+class Argmax(Transform):
 
     def __init__(self, dim: int = 0) -> None:
         self.dim = dim
@@ -27,7 +27,7 @@ class CriterionsLoader():
     def getCriterions(self, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
         criterions = {}
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "metric.measure")
+            module, name = _getModule(module_classpath, "konfai.metric.measure")
             criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
         return criterions
 
@@ -9,7 +9,6 @@ from scipy import linalg
 
 import torch.nn.functional as F
 import os
-import tqdm
 
 from typing import Callable, Union
 from functools import partial
@@ -19,10 +18,11 @@ import copy
 from abc import abstractmethod
 
 from konfai.utils.config import config
-from konfai.utils.utils import _getModule
+from konfai.utils.utils import _getModule, download_url
 from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network
+from tqdm import tqdm
 
 modelsRegister = {}
 
@@ -122,7 +122,7 @@ class LPIPS(MaskedLoss):
 
         patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
         loss = 0
-        with tqdm
+        with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
             for i, patch_input in batch_iter:
                 real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
                 loss += loss_fn_alex(real, fake).item()
@@ -146,9 +146,11 @@ class Dice(Criterion):
         target = self.flatten(target)
         return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
 
-
     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
-        target =
+        target = F.interpolate(
+            targets[0],
+            output.shape[2:],
+            mode="nearest")
         if output.shape[1] == 1:
             output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
         target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
@@ -242,7 +244,7 @@ class PerceptualLoss(Criterion):
 
     class Module():
 
-        @config(
+        @config()
        def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
             self.losses = losses
             self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
@@ -250,7 +252,7 @@ class PerceptualLoss(Criterion):
         def getLoss(self) -> dict[torch.nn.Module, float]:
             result: dict[torch.nn.Module, float] = {}
             for loss, l in self.losses.items():
-                module, name = _getModule(loss, "metric.measure")
+                module, name = _getModule(loss, "konfai.metric.measure")
                 result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
             return result
 
@@ -413,16 +415,19 @@ class FocalLoss(Criterion):
         self.gamma = gamma
         self.reduction = reduction
 
-    def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
-        target =
+    def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
+        target = F.interpolate(
+            targets[0],
+            output.shape[2:],
+            mode="nearest").long()
 
         logpt = F.log_softmax(output, dim=1)
         pt = torch.exp(logpt)
 
-        logpt = logpt.gather(1, target
-        pt = pt.gather(1, target
+        logpt = logpt.gather(1, target)
+        pt = pt.gather(1, target)
 
-        at = self.alpha[target].unsqueeze(1)
+        at = self.alpha.to(target.device)[target].unsqueeze(1)
         loss = -at * ((1 - pt) ** self.gamma) * logpt
 
         if self.reduction == "mean":
@@ -458,13 +463,11 @@ class FID(Criterion):
         return features
 
     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
-        # Calculate mean and covariance statistics
         mu1 = np.mean(real_features, axis=0)
         sigma1 = np.cov(real_features, rowvar=False)
         mu2 = np.mean(generated_features, axis=0)
         sigma2 = np.cov(generated_features, rowvar=False)
 
-        # Calculate FID score
         diff = mu1 - mu2
         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
         if np.iscomplexobj(covmean):
@@ -480,7 +483,6 @@ class FID(Criterion):
         generated_features = FID.get_features(generated_images, self.inception_model)
 
         return FID.calculate_fid(real_features, generated_features)
-
 
 class MutualInformationLoss(torch.nn.Module):
     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
@@ -498,7 +500,6 @@ class MutualInformationLoss(torch.nn.Module):
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
 
-
     def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
@@ -512,4 +513,72 @@ class MutualInformationLoss(torch.nn.Module):
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
-        return torch.mean(mi).neg() # average over the batch and channel ndims
+        return torch.mean(mi).neg() # average over the batch and channel ndims
+
+class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
+
+    def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
+        super().__init__()
+        if model_name is None:
+            return
+        self.shape = shape
+        self.in_channels = in_channels
+        self.losses: dict[torch.nn.Module, list[float]] = {}
+        for loss, weights in losses.items():
+            module, name = _getModule(loss, "konfai.metric.measure")
+            self.losses[config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)] = weights
+
+        self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
+        self.model: torch.nn.Module = torch.jit.load(self.model_path)
+        self.dim = len(self.shape)
+        self.shape = self.shape if all([s > 0 for s in self.shape]) else None
+        self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
+
+        try:
+            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
+            out = self.model(dummy_input)
+            if not isinstance(out, (list, tuple)):
+                raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
+            if len(self.weight) != len(out):
+                raise ValueError(f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)}).")
+        except Exception as e:
+            msg = (
+                f"[Model Sanity Check Failed]\n"
+                f"Input shape attempted: {dummy_input.shape}\n"
+                f"Expected output length: {len(self.weight)}\n"
+                f"Error: {type(e).__name__}: {e}"
+            )
+            raise RuntimeError(msg) from e
+        self.model = None
+
+    def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
+        if self.shape is not None and not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
+            input = torch.nn.functional.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
+        if input.shape[1] != self.in_channels:
+            input = input.repeat(tuple([1,3] + [1 for _ in range(self.dim)]))
+        return input
+
+    def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
+        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+        output = self.preprocessing(output)
+        targets = [self.preprocessing(target) for target in targets]
+        self.model.to(output.device)
+        for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
+            output_features = zipped_output[0]
+            targets_features = zipped_output[1:]
+            for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
+                for output_feature, target_feature, weight in zip(output_features ,target_features, weights):
+                    loss += weights*loss_function(output_feature, target_feature)
+        return loss
+
+    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
+        if self.model is None:
+            self.model = torch.jit.load(self.model_path)
+        loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
+        if len(output.shape) == 5 and self.dim == 2:
+            for i in range(output.shape[2]):
+                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
+            loss/=output.shape[2]
+        else:
+            loss = self._compute(output, targets)
+        return loss.to(output)
@@ -2,10 +2,10 @@ import torch
 import numpy as np
 from abc import abstractmethod
 from konfai.utils.config import config
+from functools import partial
 
 class Scheduler():
 
-    @config("Scheduler")
     def __init__(self, start_value: float) -> None:
         self.baseValue = float(start_value)
         self.it = 0
@@ -19,7 +19,6 @@ class Scheduler():
 
 class Constant(Scheduler):
 
-    @config("Constant")
     def __init__(self, value: float = 1):
         super().__init__(value)
 
@@ -28,8 +27,7 @@ class Constant(Scheduler):
 
 class CosineAnnealing(Scheduler):
 
-
-    def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
+    def __init__(self, start_value: float = 1, eta_min: float = 0.00001, T_max: int = 100):
         super().__init__(start_value)
         self.eta_min = eta_min
         self.T_max = T_max
@@ -37,13 +35,10 @@ class CosineAnnealing(Scheduler):
     def get_value(self):
         return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
 
-class
+class Warmup(torch.optim.lr_scheduler.LambdaLR):
 
-
-
-        super().__init__(start_value)
-        self.eta_min = eta_min
-        self.T_max = T_max
+    def warmup(warmup_steps: int, step: int) -> float:
+        return min(1.0, (step+1) / (warmup_steps+1))
 
-    def
-
+    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 10, last_epoch=-1, verbose="deprecated"):
+        super().__init__(optimizer, partial(Warmup.warmup, warmup_steps), last_epoch, verbose)
@@ -157,10 +157,9 @@ class Head(network.ModuleArgsDict):
 
 class ConvNeXt(network.Network):
 
-    @config("ConvNeXt")
     def __init__( self,
                 optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                schedulers
+                schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
                 outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
                 patch : ModelPatch = ModelPatch(),
                 dim : int = 3,