Moral88 0.10.0__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- Moral88/segmentation.py +166 -0
- Moral88/utils.py +15 -0
- {Moral88-0.10.0.dist-info → Moral88-0.11.0.dist-info}/METADATA +1 -1
- Moral88-0.11.0.dist-info/RECORD +11 -0
- tests/test_regression.py +1 -0
- Moral88-0.10.0.dist-info/RECORD +0 -10
- {Moral88-0.10.0.dist-info → Moral88-0.11.0.dist-info}/LICENSE +0 -0
- {Moral88-0.10.0.dist-info → Moral88-0.11.0.dist-info}/WHEEL +0 -0
- {Moral88-0.10.0.dist-info → Moral88-0.11.0.dist-info}/top_level.txt +0 -0
Moral88/segmentation.py
ADDED
@@ -0,0 +1,166 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from Moral88.utils import DataValidator
|
3
|
+
from scipy.spatial.distance import directed_hausdorff
|
4
|
+
from sklearn.metrics import f1_score as sklearn_f1_score
|
5
|
+
|
6
|
+
# Shared validator instance; every metric in this module calls it to
# normalise and check its (y_true, y_pred) inputs before computing.
validator = DataValidator()
|
7
|
+
|
8
|
+
def intersection_over_union(y_true, y_pred, num_classes, library=None, flatten=True):
    """
    Computes Intersection over Union (IoU) for integer label maps.

    For each class ``c`` in ``range(num_classes)`` the IoU is
    ``|(y_true == c) & (y_pred == c)| / |(y_true == c) | (y_pred == c)|``.
    A class that appears in neither input scores 0.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted label maps of identical shape.
    num_classes : int
        Number of classes to evaluate (labels ``0 .. num_classes - 1``).
    library : {None, 'Moral88', 'torch', 'tensorflow'}, optional
        Backend used to count pixels. Every backend computes the same
        per-class IoU; ``None`` selects the NumPy implementation.
    flatten : bool, optional
        Ravel the inputs first. The counts are global sums, so this does
        not change the result.

    Returns
    -------
    (list, numpy.float64)
        For ``None`` / ``'Moral88'``: ``(iou_per_class, mean_iou)``.
    float
        For ``'torch'`` / ``'tensorflow'``: the mean IoU over all classes.

    Raises
    ------
    ValueError
        If ``library`` is not supported. Errors raised by ``validator``
        (e.g. mismatched shapes) also propagate.
    """
    y_true, y_pred = validator.validate_segmentation_inputs(y_true, y_pred)
    validator.validate_classes(y_true, num_classes)
    validator.validate_classes(y_pred, num_classes)

    if flatten:
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()

    if library == 'Moral88' or library is None:
        iou_per_class = []
        for cls in range(num_classes):
            # Build each class mask once and reuse it for both counts.
            true_mask = y_true == cls
            pred_mask = y_pred == cls
            intersection = np.logical_and(true_mask, pred_mask).sum()
            union = np.logical_or(true_mask, pred_mask).sum()
            iou_per_class.append(intersection / union if union > 0 else 0)

        mean_iou = np.mean(iou_per_class)
        return iou_per_class, mean_iou

    elif library == 'torch':
        import torch
        true_t = torch.as_tensor(y_true)
        pred_t = torch.as_tensor(y_pred)
        scores = []
        for cls in range(num_classes):
            true_mask = true_t == cls
            pred_mask = pred_t == cls
            intersection = torch.logical_and(true_mask, pred_mask).sum().item()
            union = torch.logical_or(true_mask, pred_mask).sum().item()
            # Compare plain Python ints so the empty-class case cannot crash.
            scores.append(intersection / union if union > 0 else 0.0)
        return float(np.mean(scores))

    elif library == 'tensorflow':
        import tensorflow as tf
        true_t = tf.convert_to_tensor(y_true)
        pred_t = tf.convert_to_tensor(y_pred)

        def _count(mask):
            # Number of True entries in a boolean tensor, as a Python int.
            return int(tf.reduce_sum(tf.cast(mask, tf.int64)).numpy())

        scores = []
        for cls in range(num_classes):
            true_mask = tf.equal(true_t, cls)
            pred_mask = tf.equal(pred_t, cls)
            intersection = _count(tf.logical_and(true_mask, pred_mask))
            union = _count(tf.logical_or(true_mask, pred_mask))
            scores.append(intersection / union if union > 0 else 0.0)
        return float(np.mean(scores))

    raise ValueError("Unsupported library for IoU.")
|
46
|
+
|
47
|
+
def dice_coefficient(y_true, y_pred, num_classes, library=None, flatten=True):
    """
    Computes the Dice Coefficient for integer label maps.

    For each class ``c`` in ``range(num_classes)`` the Dice score is
    ``2 * |(y_true == c) & (y_pred == c)| / (|y_true == c| + |y_pred == c|)``.
    A class that appears in neither input scores 0.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted label maps of identical shape.
    num_classes : int
        Number of classes to evaluate (labels ``0 .. num_classes - 1``).
    library : {None, 'Moral88', 'torch', 'tensorflow'}, optional
        Backend used to count pixels. Every backend computes the same
        per-class Dice score; ``None`` selects the NumPy implementation.
    flatten : bool, optional
        Ravel the inputs first. The counts are global sums, so this does
        not change the result.

    Returns
    -------
    (list, numpy.float64)
        For ``None`` / ``'Moral88'``: ``(dice_per_class, mean_dice)``.
    float
        For ``'torch'`` / ``'tensorflow'``: the mean Dice over all classes.

    Raises
    ------
    ValueError
        If ``library`` is not supported. Errors raised by ``validator``
        (e.g. mismatched shapes) also propagate.
    """
    y_true, y_pred = validator.validate_segmentation_inputs(y_true, y_pred)
    validator.validate_classes(y_true, num_classes)
    validator.validate_classes(y_pred, num_classes)

    if flatten:
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()

    if library == 'Moral88' or library is None:
        dice_per_class = []
        for cls in range(num_classes):
            true_mask = y_true == cls
            pred_mask = y_pred == cls
            intersection = np.logical_and(true_mask, pred_mask).sum()
            total = true_mask.sum() + pred_mask.sum()
            dice_per_class.append((2 * intersection) / total if total > 0 else 0)

        mean_dice = np.mean(dice_per_class)
        return dice_per_class, mean_dice

    elif library == 'torch':
        import torch
        true_t = torch.as_tensor(y_true)
        pred_t = torch.as_tensor(y_pred)
        scores = []
        for cls in range(num_classes):
            true_mask = true_t == cls
            pred_mask = pred_t == cls
            intersection = torch.logical_and(true_mask, pred_mask).sum().item()
            total = true_mask.sum().item() + pred_mask.sum().item()
            # Plain Python numbers: no `.item()` on an int when total == 0.
            scores.append((2 * intersection) / total if total > 0 else 0.0)
        return float(np.mean(scores))

    elif library == 'tensorflow':
        import tensorflow as tf
        true_t = tf.convert_to_tensor(y_true)
        pred_t = tf.convert_to_tensor(y_pred)

        def _count(mask):
            # Number of True entries in a boolean tensor, as a Python int.
            return int(tf.reduce_sum(tf.cast(mask, tf.int64)).numpy())

        scores = []
        for cls in range(num_classes):
            true_mask = tf.equal(true_t, cls)
            pred_mask = tf.equal(pred_t, cls)
            intersection = _count(tf.logical_and(true_mask, pred_mask))
            total = _count(true_mask) + _count(pred_mask)
            scores.append((2 * intersection) / total if total > 0 else 0.0)
        return float(np.mean(scores))

    raise ValueError("Unsupported library for Dice Coefficient.")
|
89
|
+
|
90
|
+
def pixel_accuracy(y_true, y_pred, library=None):
    """
    Computes Pixel Accuracy.

    The score is the number of positions where ``y_pred`` equals ``y_true``,
    divided by the total number of positions.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted label maps of identical shape.
    library : {None, 'Moral88', 'torch', 'tensorflow'}, optional
        Backend used for the comparison; ``None`` means NumPy.

    Returns
    -------
    float
        Fraction of matching positions.

    Raises
    ------
    ValueError
        If ``library`` is not supported.
    """
    y_true, y_pred = validator.validate_segmentation_inputs(y_true, y_pred)

    backend = 'Moral88' if library is None else library

    if backend == 'Moral88':
        n_match = (y_true == y_pred).sum()
        return n_match / y_true.size

    if backend == 'torch':
        import torch
        truth = torch.tensor(y_true, dtype=torch.float32)
        guess = torch.tensor(y_pred, dtype=torch.float32)
        ratio = torch.sum(truth == guess) / torch.numel(truth)
        return ratio.item()

    if backend == 'tensorflow':
        import tensorflow as tf
        truth = tf.convert_to_tensor(y_true, dtype=tf.float32)
        guess = tf.convert_to_tensor(y_pred, dtype=tf.float32)
        n_match = tf.reduce_sum(tf.cast(truth == guess, tf.float32))
        n_total = tf.size(truth, out_type=tf.float32)
        return (n_match / n_total).numpy()

    raise ValueError("Unsupported library for Pixel Accuracy.")
|
121
|
+
|
122
|
+
def hausdorff_distance(y_true, y_pred, library=None):
    """
    Computes the symmetric Hausdorff Distance between two foreground masks.

    Foreground is every position whose value is ``> 0``; points are their
    index coordinates. The distance is the larger of the two directed
    Hausdorff distances (true -> pred, pred -> true).

    Edge cases:
      * both masks empty -> ``0.0`` (identical, empty foregrounds);
      * exactly one mask empty -> ``inf`` (no finite distance exists).

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted masks of identical shape.
    library : {None, 'Moral88'}, optional
        Only the SciPy-based implementation is available.

    Returns
    -------
    float
        The Hausdorff distance in index units.

    Raises
    ------
    ValueError
        If ``library`` is not supported.
    """
    y_true, y_pred = validator.validate_segmentation_inputs(y_true, y_pred)

    if library == 'Moral88' or library is None:
        y_true_points = np.argwhere(y_true > 0)
        y_pred_points = np.argwhere(y_pred > 0)

        # Handle empty foregrounds explicitly rather than relying on
        # directed_hausdorff's behaviour for empty point sets, which does
        # not produce a meaningful distance.
        true_empty = len(y_true_points) == 0
        pred_empty = len(y_pred_points) == 0
        if true_empty and pred_empty:
            return 0.0
        if true_empty or pred_empty:
            return np.inf

        forward = directed_hausdorff(y_true_points, y_pred_points)[0]
        backward = directed_hausdorff(y_pred_points, y_true_points)[0]
        return max(forward, backward)

    raise ValueError("Unsupported library for Hausdorff Distance.")
|
137
|
+
|
138
|
+
def f1_score(y_true, y_pred, num_classes, library=None, flatten=True):
    """
    Computes the per-class F1 Score and its macro average.

    For each class ``c`` in ``range(num_classes)``:
    ``F1 = 2*TP / (2*TP + FP + FN)``, or 0 when the denominator is 0.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted label maps of identical shape.
    num_classes : int
        Number of classes to evaluate (labels ``0 .. num_classes - 1``).
    library : {None, 'Moral88', 'sklearn'}, optional
        ``None`` / ``'Moral88'`` uses NumPy. ``'sklearn'`` delegates to
        ``sklearn.metrics.f1_score`` with the same label set and zero
        handling, so the macro value matches the default backend.
    flatten : bool, optional
        Ravel the inputs first. The counts are global sums, so this does
        not change the result.

    Returns
    -------
    (list, numpy.float64)
        For ``None`` / ``'Moral88'``: ``(f1_per_class, mean_f1)``.
    float
        For ``'sklearn'``: the macro-averaged F1 score.

    Raises
    ------
    ValueError
        If ``library`` is not supported. Errors raised by ``validator``
        (e.g. mismatched shapes) also propagate.
    """
    y_true, y_pred = validator.validate_segmentation_inputs(y_true, y_pred)
    validator.validate_classes(y_true, num_classes)
    validator.validate_classes(y_pred, num_classes)

    if flatten:
        y_true = y_true.ravel()
        y_pred = y_pred.ravel()

    if library == 'sklearn':
        # sklearn expects 1-D multiclass targets; 2-D input would be treated
        # as multilabel/multioutput. `labels` fixes the averaged set to
        # range(num_classes), and zero_division=0 scores undefined classes
        # as 0, matching the NumPy backend below.
        return sklearn_f1_score(
            y_true.ravel(),
            y_pred.ravel(),
            labels=list(range(num_classes)),
            average='macro',
            zero_division=0,
        )

    if library == 'Moral88' or library is None:
        f1_per_class = []
        for cls in range(num_classes):
            pred_is_cls = y_pred == cls
            true_is_cls = y_true == cls
            tp = np.logical_and(pred_is_cls, true_is_cls).sum()
            fp = np.logical_and(pred_is_cls, ~true_is_cls).sum()
            fn = np.logical_and(~pred_is_cls, true_is_cls).sum()

            denom = 2 * tp + fp + fn
            f1_per_class.append((2 * tp) / denom if denom > 0 else 0)

        mean_f1 = np.mean(f1_per_class)
        return f1_per_class, mean_f1

    raise ValueError("Unsupported library for F1 Score.")
|
Moral88/utils.py
CHANGED
@@ -57,6 +57,21 @@ class DataValidator:
|
|
57
57
|
|
58
58
|
return y_true, y_pred
|
59
59
|
|
60
|
+
def validate_segmentation_inputs(self, y_true, y_pred):
    """
    Ensures segmentation inputs are valid label maps and returns them as int32 arrays.

    Parameters
    ----------
    y_true, y_pred : array-like
        Ground-truth and predicted label maps. Integer, boolean, or
        integer-valued floating-point data (e.g. ``1.0``) is accepted.

    Returns
    -------
    tuple of numpy.ndarray
        ``(y_true, y_pred)`` converted to ``np.int32``.

    Raises
    ------
    ValueError
        If a floating-point input contains non-integer or non-finite values.
        A cast to int32 would silently truncate them (e.g. a probability
        map of 0.7 would become all zeros).
        Also raised if the shapes differ or the inputs have fewer than two
        dimensions.
    """
    converted = []
    for name, values in (("y_true", y_true), ("y_pred", y_pred)):
        arr = np.asarray(values)
        if np.issubdtype(arr.dtype, np.floating):
            if not (np.isfinite(arr).all() and (arr == np.trunc(arr)).all()):
                raise ValueError(
                    f"{name} contains non-integer or non-finite values; "
                    "segmentation inputs must be integer class labels."
                )
        converted.append(arr.astype(np.int32, copy=False))
    y_true, y_pred = converted

    if y_true.shape != y_pred.shape:
        raise ValueError(f"Shapes of y_true {y_true.shape} and y_pred {y_pred.shape} do not match.")

    # Shapes are equal at this point, so checking one array suffices.
    if y_true.ndim < 2:
        raise ValueError("Segmentation inputs must have at least two dimensions.")

    return y_true, y_pred
|
74
|
+
|
60
75
|
def check_array(self, array, ensure_2d: bool = True, dtype=np.float64, allow_nan: bool = False):
|
61
76
|
"""
|
62
77
|
Validates input array and converts it to specified dtype.
|
@@ -0,0 +1,11 @@
|
|
1
|
+
Moral88/__init__.py,sha256=Z7iEZUqslxRyJU2to6iX6a5Ak1XBZxU3VT4RvOCjsEU,196
|
2
|
+
Moral88/regression.py,sha256=WjNMpX0t99KGTrUKMBFg6LccnPvlnWKnjimu65BLrkc,12061
|
3
|
+
Moral88/segmentation.py,sha256=N6Pg-220JfTxV01Bwjir8kOFMW3nfI0L_MM9zGrvDvg,6470
|
4
|
+
Moral88/utils.py,sha256=4dZ165tRtwCEyz-wESp26-cZp-5Pz8HkSrmNPrKEH38,4534
|
5
|
+
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
tests/test_regression.py,sha256=j-XG9j5D24OaJBJT3ROnlvrMDrSR2sH4182U5L3PA5k,3124
|
7
|
+
Moral88-0.11.0.dist-info/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
+
Moral88-0.11.0.dist-info/METADATA,sha256=I9J_QFmYuvB-Q5YufPayenl5E4-WvCQznIMtoQlNC24,408
|
9
|
+
Moral88-0.11.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
10
|
+
Moral88-0.11.0.dist-info/top_level.txt,sha256=gg4pKIcQal4JhJAb77H5W6SHC77e-BeLTy4hxfXwmfw,14
|
11
|
+
Moral88-0.11.0.dist-info/RECORD,,
|
tests/test_regression.py
CHANGED
Moral88-0.10.0.dist-info/RECORD
DELETED
@@ -1,10 +0,0 @@
|
|
1
|
-
Moral88/__init__.py,sha256=Z7iEZUqslxRyJU2to6iX6a5Ak1XBZxU3VT4RvOCjsEU,196
|
2
|
-
Moral88/regression.py,sha256=WjNMpX0t99KGTrUKMBFg6LccnPvlnWKnjimu65BLrkc,12061
|
3
|
-
Moral88/utils.py,sha256=ggiiY5Vp6A6MbGtghftkM0MJM0R9hhR2avUbpV43_yk,3933
|
4
|
-
tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
tests/test_regression.py,sha256=w5A6eGTmVuh-eN0nTACPoQzzrX2wI5McyQuMyCvf07M,3122
|
6
|
-
Moral88-0.10.0.dist-info/LICENSE,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
-
Moral88-0.10.0.dist-info/METADATA,sha256=6YVHD8ZRgbJ-4lTVmQnJS7caxgRvHGirWgNLHDtSNPw,408
|
8
|
-
Moral88-0.10.0.dist-info/WHEEL,sha256=pkctZYzUS4AYVn6dJ-7367OJZivF2e8RA9b_ZBjif18,92
|
9
|
-
Moral88-0.10.0.dist-info/top_level.txt,sha256=gg4pKIcQal4JhJAb77H5W6SHC77e-BeLTy4hxfXwmfw,14
|
10
|
-
Moral88-0.10.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|