konfai-1.2.4-py3-none-any.whl → konfai-1.2.5-py3-none-any.whl

This diff shows the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registries.

Potentially problematic release.


This version of konfai may be problematic; see the registry's advisory page for this release for more details.

konfai/data/transform.py CHANGED
@@ -48,12 +48,13 @@ class Clip(Transform):
48
48
 
49
49
  def __init__(
50
50
  self,
51
- min_value: float = -1024,
52
- max_value: float = 1024,
51
+ min_value: float | str = -1024,
52
+ max_value: float | str = 1024,
53
53
  save_clip_min: bool = False,
54
54
  save_clip_max: bool = False,
55
+ mask: str | None = None
55
56
  ) -> None:
56
- if max_value <= min_value:
57
+ if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
57
58
  raise ValueError(
58
59
  f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
59
60
  )
@@ -61,14 +62,56 @@ class Clip(Transform):
61
62
  self.max_value = max_value
62
63
  self.save_clip_min = save_clip_min
63
64
  self.save_clip_max = save_clip_max
65
+ self.mask = mask
64
66
 
65
67
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
66
- tensor[torch.where(tensor < self.min_value)] = self.min_value
67
- tensor[torch.where(tensor > self.max_value)] = self.max_value
68
+ mask = None
69
+ if self.mask is not None:
70
+ for dataset in self.datasets:
71
+ if dataset.is_dataset_exist(self.mask, name):
72
+ mask, _ = dataset.read_data(self.mask, name)
73
+ break
74
+ if mask is None and self.mask is not None:
75
+ raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
76
+ if mask is None:
77
+ tensor_masked = tensor
78
+ else:
79
+ tensor_masked = tensor[mask == 1]
80
+
81
+ if isinstance(self.min_value, str):
82
+ if self.min_value == "min":
83
+ min_value = torch.min(tensor_masked)
84
+ elif self.min_value.startswith("percentile:"):
85
+ try:
86
+ percentile = float(self.min_value.split(":")[1])
87
+ min_value = np.percentile(tensor_masked, percentile)
88
+ except (IndexError, ValueError):
89
+ raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
90
+ else:
91
+ raise TypeError(f"Unsupported string for min_value: '{self.min_value}'. Must be a float, 'min', or 'percentile:<float>'.")
92
+ else:
93
+ min_value = self.min_value
94
+
95
+ if isinstance(self.max_value, str):
96
+ if self.max_value == "max":
97
+ max_value = torch.max(tensor_masked)
98
+ elif self.max_value.startswith("percentile:"):
99
+ try:
100
+ percentile = float(self.max_value.split(":")[1])
101
+ max_value = np.percentile(tensor_masked, percentile)
102
+ except (IndexError, ValueError):
103
+ raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
104
+ else:
105
+ raise TypeError(f"Unsupported string for max_value: '{self.max_value}'. Must be a float, 'max', or 'percentile:<float>'.")
106
+ else:
107
+ max_value = self.max_value
108
+
109
+ tensor[torch.where(tensor < min_value)] = min_value
110
+ tensor[torch.where(tensor > max_value)] = max_value
68
111
  if self.save_clip_min:
69
- cache_attribute["Min"] = self.min_value
112
+ cache_attribute["Min"] = min_value
70
113
  if self.save_clip_max:
71
- cache_attribute["Max"] = self.max_value
114
+ cache_attribute["Max"] = max_value
72
115
  return tensor
73
116
 
74
117
  def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -143,36 +186,42 @@ class Standardize(Transform):
143
186
  lazy: bool = False,
144
187
  mean: list[float] | None = None,
145
188
  std: list[float] | None = None,
189
+ mask: str | None = None
146
190
  ) -> None:
147
191
  self.lazy = lazy
148
192
  self.mean = mean
149
193
  self.std = std
194
+ self.mask = mask
150
195
 
151
196
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
197
+ mask = None
198
+ if self.mask is not None:
199
+ for dataset in self.datasets:
200
+ if dataset.is_dataset_exist(self.mask, name):
201
+ mask, _ = dataset.read_data(self.mask, name)
202
+ break
203
+ if mask is None and self.mask is not None:
204
+ raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
205
+ if mask is None:
206
+ tensor_masked = tensor
207
+ else:
208
+ tensor_masked = tensor[mask == 1]
209
+
152
210
  if "Mean" not in cache_attribute:
153
- cache_attribute["Mean"] = (
154
- torch.mean(
155
- tensor.type(torch.float32),
156
- dim=[i + 1 for i in range(len(tensor.shape) - 1)],
157
- )
158
- if self.mean is None
159
- else torch.tensor([self.mean])
160
- )
211
+ cache_attribute["Mean"] = torch.tensor([torch.mean(tensor_masked.type(torch.float32))]) if self.mean is None else torch.tensor([self.mean])
212
+
161
213
  if "Std" not in cache_attribute:
162
214
  cache_attribute["Std"] = (
163
- torch.std(
164
- tensor.type(torch.float32),
165
- dim=[i + 1 for i in range(len(tensor.shape) - 1)],
166
- )
215
+ torch.tensor([torch.std(
216
+ tensor_masked.type(torch.float32))])
167
217
  if self.std is None
168
218
  else torch.tensor([self.std])
169
219
  )
170
-
171
220
  if self.lazy:
172
221
  return tensor
173
222
  else:
174
- mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
175
- std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
223
+ mean = cache_attribute.get_tensor("Mean")
224
+ std = cache_attribute.get_tensor("Std")
176
225
  return (tensor - mean) / std
177
226
 
178
227
  def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
konfai/network/network.py CHANGED
@@ -6,8 +6,7 @@ from collections import OrderedDict
6
6
  from collections.abc import Callable, Iterable, Iterator, Sequence
7
7
  from enum import Enum
8
8
  from functools import partial
9
- from typing import Any
10
- from typing_extensions import Self
9
+ from typing import Any, Self
11
10
 
12
11
  import numpy as np
13
12
  import torch
konfai/predictor.py CHANGED
@@ -715,7 +715,9 @@ class Predictor(DistributedObject):
715
715
  path = models_directory() + self.name + "/StateDict/"
716
716
  name = sorted(os.listdir(path))[-1]
717
717
  if os.path.exists(path + name):
718
- state_dicts.append(torch.load(path + name, map_location=torch.device('cpu'), weights_only=False))
718
+ state_dicts.append(
719
+ torch.load(path + name, map_location=torch.device("cpu"), weights_only=False) # nosec B614
720
+ ) # nosec B614
719
721
  else:
720
722
  raise Exception(f"Model : {path + name} does not exist !")
721
723
  return state_dicts
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.2.4
3
+ Version: 1.2.5
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -1,13 +1,13 @@
1
1
  konfai/__init__.py,sha256=qjE9Rqxo1sMrkqGS8I5xlGQMZnjIfU-CGgSI5Wmbmbs,1231
2
2
  konfai/evaluator.py,sha256=xAKWUDvdSxqYRUsKqH6ieQF06LWa785aE4zLv4I3_i4,17850
3
3
  konfai/main.py,sha256=Fc4HcJEhPmgunj_f-QYyvQNvjHrKHSUv27Okgu6V5_A,3842
4
- konfai/predictor.py,sha256=S-KlITnyxcAMHmd7aLsoKlJSiZym63yBsJd404HWp9o,34632
4
+ konfai/predictor.py,sha256=TTMB-PowpLef-iLOaQv0X5AXV0q8n7XBEXCS0nJJBTg,34706
5
5
  konfai/trainer.py,sha256=g_TkPDUjToFGDGB7aaRZMn-fQllHV_I2GHFKUzDGF8o,27106
6
6
  konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  konfai/data/augmentation.py,sha256=vcJE7mosvUkwbpbTN_lGP0S1uJrJYGjlLrt9VnDdJYY,27792
8
8
  konfai/data/data_manager.py,sha256=I5YG2HIRWQtP_6Lu9KsEWEJzAFFMTMWOYv7dzlcdUN4,31100
9
9
  konfai/data/patching.py,sha256=jS35OxnJagKNUnJu7TzuGZpVj9fP-6H4nc2OEYOGgt8,16494
10
- konfai/data/transform.py,sha256=YCldsqTTBFFCqc_VdvyuNVs2kmV56CxQBN5XhEoPxho,27745
10
+ konfai/data/transform.py,sha256=_Mih41rpITQo19Y-fFsGdIaztL6CESrATrGQanu3uqU,30210
11
11
  konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  konfai/metric/measure.py,sha256=0mOIZKTa2u0UECpoDSbdJUhttAw_e1BlsROQQpi1oBk,27804
13
13
  konfai/metric/schedulers.py,sha256=TpYMA24FMpxRnqfhMGb0i_Mm-bzT9kySbBgvkYk-6wM,1327
@@ -24,15 +24,15 @@ konfai/models/segmentation/NestedUNet.py,sha256=W4uauwF0HY8ybi49iYiTlKLdJEyD7SaC
24
24
  konfai/models/segmentation/UNet.py,sha256=Pu_LiQdO4Mrzyn0HRE6rwxUjHGH4OG-JpzWB_U1K46g,5602
25
25
  konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  konfai/network/blocks.py,sha256=l70_oOcz5Hmol2xmxruG0kke_2SVgO3rXYXVTMSdAS8,15645
27
- konfai/network/network.py,sha256=0G3goIopB_r4PHj4ohOkbov344mVYxZjq5PH57QETmc,54829
27
+ konfai/network/network.py,sha256=PWI6W4sz7G5Pbb-79l7mL61AoyHWlQytRtBzvwh3Ro0,54800
28
28
  konfai/utils/ITK.py,sha256=HVed4Z96X1jTaWrrQNdoBMqOtVK9InAPlDBJu-5uv3g,15476
29
29
  konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  konfai/utils/config.py,sha256=a7t44CYMUT5oCDdjL94IswhCVfFbQ5FCgDWZktDDkc4,14347
31
31
  konfai/utils/dataset.py,sha256=Au22fcADKyDJMfS8Z9q8kEXLtKkoufJsH7Pwly6pALo,28288
32
32
  konfai/utils/utils.py,sha256=jCj3tZ8agQYceSY_tlVYp88UFPE5oUn6tXrqnZGrKiI,28410
33
- konfai-1.2.4.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
34
- konfai-1.2.4.dist-info/METADATA,sha256=y505AGDJqf-U3tO8uj9jQ4bo3g_fqdROVA-T4JZ96d8,2451
35
- konfai-1.2.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
36
- konfai-1.2.4.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
37
- konfai-1.2.4.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
38
- konfai-1.2.4.dist-info/RECORD,,
33
+ konfai-1.2.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
34
+ konfai-1.2.5.dist-info/METADATA,sha256=IuUBngJ6PCmGk3uSVBpGBfOYfptSfgkjLkGpLEwbb4w,2451
35
+ konfai-1.2.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
36
+ konfai-1.2.5.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
37
+ konfai-1.2.5.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
38
+ konfai-1.2.5.dist-info/RECORD,,
File without changes