konfai 1.2.3.tar.gz → 1.2.5.tar.gz
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release: this version of konfai might be problematic.
- {konfai-1.2.3 → konfai-1.2.5}/PKG-INFO +1 -1
- {konfai-1.2.3 → konfai-1.2.5}/konfai/data/data_manager.py +2 -3
- {konfai-1.2.3 → konfai-1.2.5}/konfai/data/transform.py +71 -22
- {konfai-1.2.3 → konfai-1.2.5}/konfai/network/network.py +1 -2
- {konfai-1.2.3 → konfai-1.2.5}/konfai/predictor.py +3 -2
- {konfai-1.2.3 → konfai-1.2.5}/konfai/trainer.py +0 -1
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.2.3 → konfai-1.2.5}/pyproject.toml +1 -1
- {konfai-1.2.3 → konfai-1.2.5}/LICENSE +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/README.md +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/__init__.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/data/__init__.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/data/augmentation.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/data/patching.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/evaluator.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/main.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/metric/__init__.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/metric/measure.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/metric/schedulers.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/generation/gan.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/generation/vae.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/registration/registration.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/representation/representation.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/segmentation/NestedUNet.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/models/segmentation/UNet.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/network/__init__.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/network/blocks.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/utils/ITK.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/utils/__init__.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/utils/config.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/utils/dataset.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai/utils/utils.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/setup.cfg +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/tests/test_config.py +0 -0
- {konfai-1.2.3 → konfai-1.2.5}/tests/test_dataset.py +0 -0
```diff
--- konfai-1.2.3/konfai/data/data_manager.py
+++ konfai-1.2.5/konfai/data/data_manager.py
@@ -317,7 +317,6 @@ class Subset:
                 else:
                     index_set = index_set.intersection(set(self._get_index(s, names)))
             index = list(index_set)
-            print(index)
         else:
             index = self._get_index(self.subset, names)
         if self.shuffle:
@@ -380,8 +379,8 @@ class Data(ABC):
             use_cache=use_cache,
         )
         self.dataLoader_args = {
-            "num_workers":
-            "pin_memory":
+            "num_workers": 0,
+            "pin_memory": False,
         }
         self.data: list[list[dict[str, list[DatasetManager]]]] = []
         self.mapping: list[list[list[tuple[int, int, int]]]] = []
```
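The first hunk drops a leftover debug `print`. The second replaces the loader arguments (whose old values are truncated in the rendered diff) with hard-coded single-process defaults. A minimal sketch of what these two settings control once the dict is passed to `torch.utils.data.DataLoader` (the toy dataset below is illustrative, not konfai's):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(8, 3), torch.randint(0, 2, (8,)))

# num_workers=0 loads batches in the main process (no worker subprocesses);
# pin_memory=False keeps batches in pageable host memory (pinning only helps
# when copying batches to a CUDA device).
loader = DataLoader(dataset, batch_size=4, **{"num_workers": 0, "pin_memory": False})

for features, labels in loader:
    print(features.shape, labels.shape)  # torch.Size([4, 3]) torch.Size([4])
```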
```diff
--- konfai-1.2.3/konfai/data/transform.py
+++ konfai-1.2.5/konfai/data/transform.py
@@ -48,12 +48,13 @@ class Clip(Transform):
 
     def __init__(
         self,
-        min_value: float = -1024,
-        max_value: float = 1024,
+        min_value: float | str = -1024,
+        max_value: float | str = 1024,
         save_clip_min: bool = False,
         save_clip_max: bool = False,
+        mask: str | None = None
     ) -> None:
-        if max_value <= min_value:
+        if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
             raise ValueError(
                 f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
             )
```
```diff
@@ -61,14 +62,56 @@
         self.max_value = max_value
         self.save_clip_min = save_clip_min
         self.save_clip_max = save_clip_max
+        self.mask = mask
 
     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        tensor[torch.where(tensor < self.min_value)] = self.min_value
-        tensor[torch.where(tensor > self.max_value)] = self.max_value
+        mask = None
+        if self.mask is not None:
+            for dataset in self.datasets:
+                if dataset.is_dataset_exist(self.mask, name):
+                    mask, _ = dataset.read_data(self.mask, name)
+                    break
+        if mask is None and self.mask is not None:
+            raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
+        if mask is None:
+            tensor_masked = tensor
+        else:
+            tensor_masked = tensor[mask == 1]
+
+        if isinstance(self.min_value, str):
+            if self.min_value == "min":
+                min_value = torch.min(tensor_masked)
+            elif self.min_value.startswith("percentile:"):
+                try:
+                    percentile = float(self.min_value.split(":")[1])
+                    min_value = np.percentile(tensor_masked, percentile)
+                except (IndexError, ValueError):
+                    raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
+            else:
+                raise TypeError(f"Unsupported string for min_value: '{self.min_value}'. Must be a float, 'min', or 'percentile:<float>'.")
+        else:
+            min_value = self.min_value
+
+        if isinstance(self.max_value, str):
+            if self.max_value == "max":
+                max_value = torch.max(tensor_masked)
+            elif self.max_value.startswith("percentile:"):
+                try:
+                    percentile = float(self.max_value.split(":")[1])
+                    max_value = np.percentile(tensor_masked, percentile)
+                except (IndexError, ValueError):
+                    raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
+            else:
+                raise TypeError(f"Unsupported string for max_value: '{self.max_value}'. Must be a float, 'max', or 'percentile:<float>'.")
+        else:
+            max_value = self.max_value
+
+        tensor[torch.where(tensor < min_value)] = min_value
+        tensor[torch.where(tensor > max_value)] = max_value
         if self.save_clip_min:
-            cache_attribute["Min"] = self.min_value
+            cache_attribute["Min"] = min_value
         if self.save_clip_max:
-            cache_attribute["Max"] = self.max_value
+            cache_attribute["Max"] = max_value
         return tensor
 
     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
```
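The new `Clip.__call__` accepts string bounds: `"min"`/`"max"` (taken from the, possibly masked, tensor itself) or `"percentile:<float>"`. A standalone sketch of the same bound resolution, assuming nothing from konfai (`resolve_bound` and the random tensor are illustrative names, not konfai APIs):

```python
import numpy as np
import torch

def resolve_bound(spec: float | str, values: torch.Tensor, extremum: str) -> float:
    """Mirror of Clip's new bound parsing: a float, "min"/"max", or "percentile:<float>"."""
    if isinstance(spec, str):
        if spec == extremum:  # "min" or "max"
            return float(torch.min(values) if extremum == "min" else torch.max(values))
        if spec.startswith("percentile:"):
            return float(np.percentile(values.numpy(), float(spec.split(":")[1])))
        raise TypeError(f"Unsupported string bound: {spec!r}")
    return float(spec)

t = torch.randn(64, 64) * 500
low = resolve_bound("percentile:1", t, "min")
high = resolve_bound("percentile:99", t, "max")
t[t < low] = low    # same clipping as the diff's torch.where assignments
t[t > high] = high
```

When `mask` names a dataset group, the bounds are computed from `tensor[mask == 1]` only, while the clipping itself is still applied to the whole tensor.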
```diff
@@ -143,36 +186,42 @@ class Standardize(Transform):
         lazy: bool = False,
         mean: list[float] | None = None,
         std: list[float] | None = None,
+        mask: str | None = None
     ) -> None:
         self.lazy = lazy
         self.mean = mean
         self.std = std
+        self.mask = mask
 
     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        mask = None
+        if self.mask is not None:
+            for dataset in self.datasets:
+                if dataset.is_dataset_exist(self.mask, name):
+                    mask, _ = dataset.read_data(self.mask, name)
+                    break
+        if mask is None and self.mask is not None:
+            raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
+        if mask is None:
+            tensor_masked = tensor
+        else:
+            tensor_masked = tensor[mask == 1]
+
         if "Mean" not in cache_attribute:
-            cache_attribute["Mean"] = (
-                torch.mean(
-                    tensor.type(torch.float32),
-                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
-                )
-                if self.mean is None
-                else torch.tensor([self.mean])
-            )
+            cache_attribute["Mean"] = torch.tensor([torch.mean(tensor_masked.type(torch.float32))]) if self.mean is None else torch.tensor([self.mean])
+
         if "Std" not in cache_attribute:
             cache_attribute["Std"] = (
-                torch.std(
-                    tensor.type(torch.float32),
-                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
-                )
+                torch.tensor([torch.std(
+                    tensor_masked.type(torch.float32))])
                 if self.std is None
                 else torch.tensor([self.std])
             )
-
         if self.lazy:
             return tensor
         else:
-            mean = cache_attribute.get_tensor("Mean")
-            std = cache_attribute.get_tensor("Std")
+            mean = cache_attribute.get_tensor("Mean")
+            std = cache_attribute.get_tensor("Std")
             return (tensor - mean) / std
 
     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
```
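`Standardize` gains the same `mask` option: the mean and standard deviation are computed from the voxels where `mask == 1`, while the normalization is applied to the whole tensor. The hunk also collapses the statistics from the old per-channel reduction (`dim=[i + 1 ...]`) to a single scalar wrapped in a one-element tensor. A minimal standalone sketch of the new semantics (the helper name and toy data are illustrative):

```python
import torch

def standardize_masked(tensor: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
    # Statistics come from the masked voxels only; the whole tensor is normalized.
    values = tensor if mask is None else tensor[mask == 1]
    mean = torch.mean(values.type(torch.float32))
    std = torch.std(values.type(torch.float32))
    return (tensor - mean) / std

image = torch.randn(1, 32, 32) * 100 + 40
body_mask = (image > 0).int()  # toy foreground mask
normalized = standardize_masked(image, body_mask)
```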
```diff
--- konfai-1.2.3/konfai/network/network.py
+++ konfai-1.2.5/konfai/network/network.py
@@ -6,8 +6,7 @@ from collections import OrderedDict
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from enum import Enum
 from functools import partial
-from typing import Any
-from typing_extensions import Self
+from typing import Any, Self
 
 import numpy as np
 import torch
```
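`typing.Self` only exists in the standard library from Python 3.11 onward, so dropping the `typing_extensions` import effectively assumes Python >= 3.11 (the package's actual floor is set in `pyproject.toml`, which this diff does not show). If older interpreters still had to be supported, a guarded import would be the usual pattern:

```python
import sys

if sys.version_info >= (3, 11):
    from typing import Any, Self
else:
    from typing import Any
    from typing_extensions import Self  # backport for Python < 3.11
```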
```diff
--- konfai-1.2.3/konfai/predictor.py
+++ konfai-1.2.5/konfai/predictor.py
@@ -407,7 +407,6 @@ class _Predictor:
         """
 
         self.model_composite.eval()
-        self.model_composite.to(torch.float32)
         self.model_composite.module.set_state(NetState.PREDICTION)
         self.dataloader_prediction.dataset.load("Prediction")
         with tqdm.tqdm(
```
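The removed line unconditionally cast every model parameter and buffer to `float32` before prediction; the diff does not say why it was dropped, but the mechanical effect of such a call is easy to demonstrate:

```python
import torch

model = torch.nn.Linear(4, 2).to(torch.float16)
print(next(model.parameters()).dtype)  # torch.float16

# The call removed in 1.2.5 forced this in-place dtype conversion:
model.to(torch.float32)
print(next(model.parameters()).dtype)  # torch.float32
```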
```diff
@@ -716,7 +715,9 @@ class Predictor(DistributedObject):
         path = models_directory() + self.name + "/StateDict/"
         name = sorted(os.listdir(path))[-1]
         if os.path.exists(path + name):
-            state_dicts.append(
+            state_dicts.append(
+                torch.load(path + name, map_location=torch.device("cpu"), weights_only=False)  # nosec B614
+            )  # nosec B614
         else:
             raise Exception(f"Model : {path + name} does not exist !")
         return state_dicts
```
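`weights_only=False` tells `torch.load` to run the full (unsafe) pickle machinery, which Bandit flags as B614, hence the `# nosec` suppressions; it is needed when a checkpoint stores arbitrary Python objects rather than plain tensors. For comparison, a sketch of the restricted mode on a tensors-only checkpoint:

```python
import torch

torch.save(torch.nn.Linear(4, 2).state_dict(), "checkpoint.pt")

# weights_only=True limits unpickling to tensors and plain containers,
# which is safe for untrusted files but fails on pickled custom objects.
state_dict = torch.load("checkpoint.pt", map_location=torch.device("cpu"), weights_only=True)
```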
```diff
--- konfai-1.2.3/konfai/trainer.py
+++ konfai-1.2.5/konfai/trainer.py
@@ -632,7 +632,6 @@ class Trainer(DistributedObject):
         world_size (int): Total number of distributed processes.
         """
         state = State[konfai_state()]
-        print(checkpoints_directory() + self.name + "/")
         if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
             if os.environ["KONFAI_OVERWRITE"] != "True":
                 accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
```
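Besides deleting the debug `print`, the context lines show the overwrite guard: when a checkpoint directory already exists and `KONFAI_OVERWRITE` is not `"True"`, the user is prompted. A minimal mirror of that flow (the abort branch is an assumption; the hunk cuts off before it):

```python
import os

# Mirrors the guard around the deleted print: the env var short-circuits the prompt.
if os.environ.get("KONFAI_OVERWRITE") != "True":
    accept = input("The model already exists ! Do you want to overwrite it (yes,no) : ")
    if accept != "yes":  # assumed handling; not shown in the diff
        raise SystemExit("Aborted: refusing to overwrite existing checkpoints.")
```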