BETTER-NMA 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BETTER_NMA/__init__.py +15 -0
- BETTER_NMA/adversarial_score.py +19 -0
- BETTER_NMA/change_cluster_name.py +0 -0
- BETTER_NMA/detect_attack.py +108 -0
- BETTER_NMA/find_lca.py +21 -0
- BETTER_NMA/main.py +285 -0
- BETTER_NMA/nma_creator.py +108 -0
- BETTER_NMA/plot.py +131 -0
- BETTER_NMA/query_image.py +22 -0
- BETTER_NMA/train_adversarial_detector.py +21 -0
- BETTER_NMA/utilss/__init__.py +0 -0
- BETTER_NMA/utilss/classes/__init__.py +0 -0
- BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
- BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
- BETTER_NMA/utilss/classes/dendrogram.py +131 -0
- BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
- BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
- BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
- BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
- BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
- BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
- BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
- BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
- BETTER_NMA/utilss/classes/score_calculator.py +165 -0
- BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
- BETTER_NMA/utilss/enums/__init__.py +0 -0
- BETTER_NMA/utilss/enums/explanation_method.py +6 -0
- BETTER_NMA/utilss/enums/heap_types.py +5 -0
- BETTER_NMA/utilss/models_utils.py +18 -0
- BETTER_NMA/utilss/photos_uitls.py +72 -0
- BETTER_NMA/utilss/photos_utils.py +104 -0
- BETTER_NMA/utilss/verbal_explanation.py +15 -0
- BETTER_NMA/utilss/wordnet_utils.py +177 -0
- BETTER_NMA/white_box_testing.py +101 -0
- BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
- BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
- BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
- BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/__init__.py
ADDED
@@ -0,0 +1,15 @@
|
|
1
|
+
from .main import NMA
|
2
|
+
from .white_box_testing import (
|
3
|
+
visualize_problematic_images,
|
4
|
+
analyze_white_box_results,
|
5
|
+
get_white_box_analysis
|
6
|
+
)
|
7
|
+
from .white_box_testing import save_white_box_results, load_white_box_results
|
8
|
+
|
9
|
+
# Names exported by ``from BETTER_NMA import *``: the NMA class plus the
# white-box helper functions imported above.
__all__ = [
    "NMA",
    "visualize_problematic_images",
    "analyze_white_box_results",
    "get_white_box_analysis",
    "save_white_box_results",
    "load_white_box_results",
]
|
@@ -0,0 +1,19 @@
|
|
1
|
+
from .utilss.classes.score_calculator import ScoreCalculator
|
2
|
+
from .utilss.photos_utils import preprocess_loaded_image
|
3
|
+
|
4
|
+
def get_adversarial_score(image, model, full_z, class_names, top_k=5):
    """
    Score a single image with ScoreCalculator.calculate_adversarial_score.

    Parameters:
    - image: Image passed to preprocess_loaded_image before prediction
    - model: Model whose predict() output is scored
    - full_z: The full Z matrix from hierarchical clustering
    - class_names: List of class names corresponding to model outputs
    - top_k: Forwarded to calculate_adversarial_score (default: 5)

    Returns:
    - The value returned by calculate_adversarial_score for the first
      prediction row of the model
    """
    calculator = ScoreCalculator(full_z, class_names)

    # The second element returned by the preprocessing helper is not needed here.
    model_input, _unused = preprocess_loaded_image(model, image)

    first_row = model.predict(model_input, verbose=0)[0]
    return calculator.calculate_adversarial_score(first_row, top_k=top_k)
|
File without changes
|
@@ -0,0 +1,108 @@
|
|
1
|
+
from .utilss.classes.score_calculator import ScoreCalculator
|
2
|
+
from .utilss.photos_utils import preprocess_loaded_image
|
3
|
+
from .utilss.models_utils import get_top_k_predictions
|
4
|
+
import matplotlib.pyplot as plt
|
5
|
+
|
6
|
+
def plot_detection_result(detection_result, figsize=(12, 8), top_k=5):
    """
    Show a two-panel figure: the input image and its top predictions,
    annotated with the detector verdict.

    Parameters:
    - detection_result: Dictionary with keys 'image', 'predictions', 'result', 'probability'
    - figsize: Tuple for figure size
    - top_k: Number of top predictions to display

    Returns: None (displays the figure with plt.show()).
    """
    fig, (image_ax, bars_ax) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: the image itself, without axes.
    image_ax.imshow(detection_result['image'])
    image_ax.axis('off')
    image_ax.set_title('Input Image', fontsize=14, fontweight='bold')

    # Right panel: horizontal bars for the first top_k predictions.
    top_entries = detection_result['predictions'][:top_k]

    # Entries are either (class_name, probability) tuples or dictionaries
    # with 'class' and 'probability' keys; pick the matching accessors.
    if top_entries and isinstance(top_entries[0], tuple):
        name_key, score_key = 0, 1
    else:
        name_key, score_key = 'class', 'probability'
    names = [entry[name_key] for entry in top_entries]
    scores = [entry[score_key] for entry in top_entries]

    bars = bars_ax.barh(range(len(names)), scores, color='skyblue')
    bars_ax.set_yticks(range(len(names)))
    bars_ax.set_yticklabels(names, fontsize=10)
    bars_ax.set_xlabel('Probability', fontsize=12)
    bars_ax.set_title(f'Top {top_k} Predictions', fontsize=14, fontweight='bold')
    bars_ax.set_xlim(0, 1)

    # Write each probability just to the right of its bar.
    for row, (_bar, score) in enumerate(zip(bars, scores)):
        bars_ax.text(score + 0.01, row, f'{score:.3f}',
                     va='center', fontsize=9)

    verdict = detection_result['result']
    verdict_probability = detection_result['probability']

    # Red box for an adversarial verdict, green for anything else.
    if verdict == 'Adversarial':
        box_color = 'red'
    else:
        box_color = 'green'

    summary = f'Detection: {verdict}\nConfidence: {verdict_probability:.3f}'
    box_style = {'boxstyle': 'round', 'facecolor': box_color, 'alpha': 0.3}
    bars_ax.text(0.02, 0.98, summary, transform=bars_ax.transAxes, fontsize=12,
                 verticalalignment='top', bbox=box_style, fontweight='bold')

    plt.tight_layout()
    plt.show()
|
63
|
+
|
64
|
+
def detect_adversarial_image(model, image, detector, Z_full, labels, plot_result=False):
    """
    Classify an image as 'Adversarial' or 'Clean' with a trained detector.

    Parameters:
    - model: The original model being attacked
    - image: The input image to check
    - detector: The trained detector; its predict() and predict_proba() are
      called on the single adversarial-score feature
    - Z_full: Hierarchical clustering data
    - labels: List of class names
    - plot_result: Whether to plot the detection result (default: False)

    Returns:
    - Dictionary with 'image', 'predictions', 'result', and 'probability'
    """
    model_input, pil_image = preprocess_loaded_image(model, image)
    calculator = ScoreCalculator(Z_full, labels)

    # Score the first prediction row of the original model.
    raw_predictions = model.predict(model_input, verbose=0)
    adversarial_score = calculator.calculate_adversarial_score(raw_predictions[0])

    # The detector is given one sample with one feature.
    detector_input = [[adversarial_score]]
    predicted_label = detector.predict(detector_input)[0]  # 1 is treated as adversarial
    adversarial_probability = detector.predict_proba(detector_input)[0][1]

    verdict = 'Clean'
    if predicted_label == 1:
        verdict = 'Adversarial'

    outcome = dict(
        image=pil_image,
        predictions=get_top_k_predictions(model, model_input, labels),
        result=verdict,
        probability=adversarial_probability,
    )

    if plot_result:
        plot_detection_result(outcome)

    return outcome
|
BETTER_NMA/find_lca.py
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
from .utilss.classes.score_calculator import ScoreCalculator
|
2
|
+
|
3
|
+
def get_lca(label1, label2, dendrogram, class_names):
    """
    Find the Lowest Common Ancestor (LCA) of two classes in a dendrogram.

    Parameters:
    - label1: Name of the first class
    - label2: Name of the second class
    - dendrogram: Object exposing the Z matrix as ``.Z`` and a
      ``get_node_name(index)`` method
    - class_names: List of class names corresponding to model outputs

    Returns:
    - lca: The name get_node_name() reports for the LCA node

    Raises:
    - ValueError: from list.index() when a label is not in class_names
    """
    calculator = ScoreCalculator(dendrogram.Z, class_names)

    # Resolve both labels to their positions, first label first.
    first_idx, second_idx = (class_names.index(name) for name in (label1, label2))

    # Only the LCA index is needed; the first returned value is ignored.
    _ignored, lca_idx = calculator.count_ancestors_to_lca(first_idx, second_idx)
    return dendrogram.get_node_name(lca_idx)
|
BETTER_NMA/main.py
ADDED
@@ -0,0 +1,285 @@
|
|
1
|
+
import tempfile
|
2
|
+
import os
|
3
|
+
from .nma_creator import preprocessing
|
4
|
+
from .plot import plot, plot_sub_dendrogram
|
5
|
+
from .train_adversarial_detector import create_logistic_regression_detector
|
6
|
+
from .utilss.classes.whitebox_testing import WhiteBoxTesting
|
7
|
+
from .detect_attack import detect_adversarial_image
|
8
|
+
from .query_image import query_image
|
9
|
+
from .utilss.verbal_explanation import get_verbal_explanation
|
10
|
+
from .white_box_testing import analyze_white_box_results, get_white_box_analysis
|
11
|
+
from .adversarial_score import get_adversarial_score
|
12
|
+
from .find_lca import get_lca
|
13
|
+
|
14
|
+
class NMA:
    """Facade bundling the dendrogram, white-box testing and adversarial detection helpers."""

    def __init__(self, x_train, y_train, labels, model, explanation_method, top_k=4, min_confidence=0.8, infinity=None, threshold=1e-6, save_connections=False, batch_size=32):
        """
        Initializes the NMA object with training data, model, and parameters.

        Inputs:
        - x_train: Training images (e.g., NumPy array).
        - y_train: Training labels (e.g., list or array).
        - labels: List of class labels (e.g., ['cat', 'dog']).
        - model: Pre-trained model for predictions.
        - explanation_method: Method for generating explanations.
        - top_k: Number of top predictions to consider (default: 4).
        - min_confidence: Minimum confidence threshold (default: 0.8).
        - infinity: Value for infinity in calculations, usually the labels count (default: None).
        - threshold: Threshold for clustering, depends on the model (default: 1e-6).
        - save_connections: Whether to save edges dataframe, use True for white box testing (default: False).
        - batch_size: Batch size for processing (default: 32).

        Outputs: None (initializes the object).

        Explanation: Sets up the dendrogram (visual explanation) and edges dataframe using preprocessing.
        """
        self.model = model
        self.explanation_method = explanation_method
        self.top_k = top_k
        self.labels = labels
        self.min_confidence = min_confidence
        self.infinity = infinity
        self.threshold = threshold
        self.save_connections = save_connections
        self.batch_size = batch_size
        # Set by train_adversarial_detector(); detect_attack() refuses to run while None.
        self.detector = None

        self.dendrogram_object, self.edges_df = preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size)
        print("NMA initialized")

    def _model_name(self):
        """Return the model's ``name`` attribute, or "model" when it has none."""
        return getattr(self.model, 'name', "model")

    ## plot functions: ##

    def plot(self, sub_labels=None, title="Sub Dendrogram", figsize=(12, 8), **kwargs):
        """
        Plots the dendrogram.

        Inputs:
        - sub_labels (optional, if not defined, uses all labels): List of labels to highlight (e.g., ['cat', 'dog']).
        - title (optional): Plot title (default: "Sub Dendrogram").
        - figsize (optional): Figure size (default: (12, 8)).
        - **kwargs: Additional arguments forwarded to the module-level plot function.

        Outputs: None (displays plot).

        Explanation: Visualizes the full dendrogram or highlights sub-labels. includes json representation of the dendrogram.
        """
        # Inside a method the bare name resolves to the module-level plot function.
        plot(self, sub_labels, title=title, figsize=figsize, **kwargs)

    def plot_sub_dendrogram(self, sub_labels, title="Sub Dendrogram", figsize=(12, 8)):
        """
        Plots a sub-dendrogram for specific labels.

        Inputs:
        - sub_labels: List of labels to include (e.g., ['apple', 'banana']).
        - title (optional): Plot title (default: "Sub Dendrogram").
        - figsize (optional): Figure size (default: (12, 8)).

        Outputs: None (displays plot).

        Explanation: Renders a subset of the dendrogram based on provided labels.
        """
        plot_sub_dendrogram(self.dendrogram_object.Z, self.labels, sub_labels, title=title, figsize=figsize)

    ## white box testing functions: ##

    def white_box_testing(self, source_labels, target_labels, analyze_results=False, x_train=None, encode_images=True):
        """
        Performs white-box testing to find problematic images.

        Inputs:
        - source_labels: List of source labels (e.g., ['cat']).
        - target_labels: List of target labels (e.g., ['dog']).
        - analyze_results (optional): Whether to analyze results (default: False).
        - x_train (optional): Training images for analysis.
        - encode_images (optional): Whether to encode images (default: True).

        Outputs: Dictionary of problematic images or analyzed results.

        Raises: ValueError when the edges dataframe was not saved at construction.

        Explanation: Finds images that could be misclassified using edges dataframe.
        """
        if self.edges_df is None:
            raise ValueError("White box testing requires edges_df. Initialize NMA with save_connections=True")

        whitebox = WhiteBoxTesting(self._model_name())
        problematic_imgs_dict = whitebox.find_problematic_images(
            source_labels, target_labels, self.edges_df, self.explanation_method)

        if analyze_results:
            return analyze_white_box_results(problematic_imgs_dict, x_train, encode_images)

        return problematic_imgs_dict

    def get_white_box_analysis(self, source_labels, target_labels, x_train=None):
        """
        Runs the module-level white-box analysis on this instance's edges dataframe.

        Inputs:
        - source_labels: List of source labels.
        - target_labels: List of target labels.
        - x_train (optional): Training images forwarded to the analysis.

        Outputs: Whatever the module-level get_white_box_analysis returns.

        Raises: ValueError when the edges dataframe was not saved at construction.

        Explanation: Writes the edges dataframe to a temporary CSV, passes its path to the
        analysis function, and always removes the temporary file afterwards.
        """
        if self.edges_df is None:
            raise ValueError("White box testing requires edges_df. Initialize NMA with save_connections=True")

        # Only the path is needed; the handle is closed before the CSV is written by name.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f:
            temp_path = f.name

        try:
            # Writing inside the try guarantees the file is removed even if to_csv fails.
            self.edges_df.to_csv(temp_path, index=False)
            return get_white_box_analysis(
                edges_df_path=temp_path,
                model_filename=self._model_name(),
                dataset_str=str(self.explanation_method),
                source_labels=source_labels,
                target_labels=target_labels,
                x_train=x_train
            )
        finally:
            os.unlink(temp_path)

    ## adversarial detection functions: ##

    def train_adversarial_detector(self, authentic_images, attacked_images):
        """
        Trains an adversarial detector.

        Inputs:
        - authentic_images: Array of clean images.
        - attacked_images: Array of adversarial images.

        Outputs: Trained detector model (also stored on self.detector).

        Raises: ValueError when no dendrogram object is available.

        Explanation: Trains a logistic regression detector using dendrogram data.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data for adversarial detection")

        self.detector = create_logistic_regression_detector(
            self.dendrogram_object.Z,
            self.model,
            authentic_images,
            attacked_images,
            self.labels
        )
        return self.detector

    def detect_attack(self, image, plot_result=False):
        """
        Detects if an image is adversarial.

        Inputs:
        - image: Image to analyze.
        - plot_result (optional): Whether to plot results (default: False).

        Outputs: The result returned by detect_adversarial_image.

        Raises: ValueError when no detector has been trained yet.

        Explanation: Uses trained detector to check for adversarial attacks.
        """
        if self.detector is None:
            raise ValueError("Adversarial detector not trained. Call train_adversarial_detector first.")

        return detect_adversarial_image(
            self.model,
            image,
            self.detector,
            self.dendrogram_object.Z,
            self.labels,
            plot_result=plot_result
        )

    def find_lca(self, label1, label2):
        """
        Finds the lowest common ancestor of two labels.

        Inputs:
        - label1: First label (e.g., 'cat').
        - label2: Second label (e.g., 'dog').

        Outputs: LCA cluster or label.

        Explanation: Determines the LCA in the dendrogram hierarchy.
        """
        return get_lca(label1, label2, self.dendrogram_object, self.labels)

    def adversarial_score(self, image, top_k=5):
        """
        Computes adversarial score for an image.

        Inputs:
        - image: Image to score.
        - top_k (optional): Number of top predictions (default: 5).

        Outputs: Adversarial score.

        Explanation: Calculates score based on predictions and dendrogram.
        """
        return get_adversarial_score(image, self.model, self.dendrogram_object.Z, self.labels, top_k=top_k)

    ## query and explanation functions: ##

    def query_image(self, image, top_k=5):
        """
        Queries the model for predictions and explanations.

        Inputs:
        - image: Image to query.
        - top_k (optional): Number of top predictions (default: 5).

        Outputs: Tuple of predictions and explanations.

        Raises: ValueError when the dendrogram, the labels or the model are missing.

        Explanation: Predicts and explains using dendrogram and verbal explanation of the dendrogram.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data to query images")

        if self.labels is None or len(self.labels) == 0:
            raise ValueError("NMA must be initialized with labels to query images")

        if self.model is None:
            raise ValueError("NMA must be initialized with a model to query images")

        return query_image(image, self.model, self.labels, self.dendrogram_object, top_k=top_k)

    def verbal_explanation(self, image):
        """
        Generates a verbal explanation for an image.

        Inputs:
        - image: Image to explain.

        Outputs: Verbal explanation, or None when the query returns None.

        Raises: ValueError under the same conditions as query_image.

        Explanation: Calls query_image and returns the explanation.
        """
        # query_image performs the dendrogram/labels/model validation, so it is not repeated here.
        result = self.query_image(image)
        if result is None:
            return None
        _predictions, explanation = result
        return explanation

    def change_cluster_name(self, cluster_id, new_name):
        """
        Renames a cluster in the dendrogram.

        Inputs:
        - cluster_id: ID of the cluster to rename.
        - new_name: New name for the cluster.

        Outputs: None (prints success or raises error).

        Raises: ValueError when there is no dendrogram or the rename reports failure.

        Explanation: Updates the cluster name if valid.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data to change cluster names")

        result = self.dendrogram_object.rename_cluster(cluster_id, new_name)
        if not result:
            raise ValueError(f"Failed to rename cluster: {cluster_id}")

        print(f"Cluster {cluster_id} renamed to {new_name}")
|
@@ -0,0 +1,108 @@
|
|
1
|
+
import tensorflow as tf
|
2
|
+
import pandas as pd
|
3
|
+
from igraph import Graph
|
4
|
+
from .utilss.enums.explanation_method import ExplanationMethod
|
5
|
+
from .utilss.classes.preprocessing.batch_predictor import BatchPredictor
|
6
|
+
from .utilss.classes.preprocessing.heap_processor import HeapProcessor
|
7
|
+
from .utilss.classes.preprocessing.graph_builder import GraphBuilder
|
8
|
+
from .utilss.classes.preprocessing.hierarchical_clustering_builder import HierarchicalClusteringBuilder
|
9
|
+
from .utilss.classes.preprocessing.z_builder import ZBuilder
|
10
|
+
from .utilss.classes.dendrogram import Dendrogram
|
11
|
+
|
12
|
+
def preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size=32):
    """
    Build the dendrogram object (and optionally the edges dataframe) from training data.

    Images are predicted in batches. For every image whose top prediction is
    above min_confidence and equals its true label, an edge is recorded
    between the true label and each other predicted label. The resulting
    graph is turned into a hierarchical clustering and then into a Z matrix
    wrapped in a Dendrogram.

    Parameters:
    - x_train: Indexable collection of images (len() is used to detect the last batch)
    - y_train: Labels indexed in parallel with x_train
    - labels: List of class labels; used as graph vertices
    - model: Model handed to BatchPredictor
    - explanation_method: Passed to GraphBuilder and HeapProcessor; compared with
      ExplanationMethod.DISSIMILARITY.value to decide whether infinity edges are added
    - top_k, threshold: Forwarded to BatchPredictor.get_top_predictions
    - min_confidence: The top prediction's probability must exceed this value
    - infinity: Passed to GraphBuilder
    - save_connections: When true, the collected edges are returned as a DataFrame
    - batch_size: Batch size handed to BatchPredictor (default: 32)

    Returns:
    - (dendrogram_object, edges_df); edges_df is None unless save_connections is true

    Raises:
    - Any exception raised while preprocessing is printed and then re-raised, so
      callers never receive an implicit None in place of the result tuple.

    NOTE(review): the nesting of the per-image section was reconstructed from a
    listing without indentation; confirm against the packaged file.
    """
    try:
        graph = Graph(directed=False)
        graph.add_vertices(labels)

        edges_data = []
        batch_images = []
        true_labels = []
        original_dataset_positions = []

        predictor = BatchPredictor(model, batch_size)
        builder = GraphBuilder(explanation_method, infinity)

        for i, image in enumerate(x_train):
            source_label = y_train[i]
            batch_images.append(image)
            true_labels.append(source_label)
            original_dataset_positions.append(i)

            # Flush when the batch is full or the last image has been queued.
            if len(batch_images) == predictor.batch_size or i == len(x_train) - 1:
                top_predictions_batch = predictor.get_top_predictions(
                    batch_images, labels, top_k, threshold
                )
                # Shared by every image in this batch (it is not reset per image).
                added_labels = []
                for j, top_predictions in enumerate(top_predictions_batch):
                    current_label = true_labels[j]
                    original_index = original_dataset_positions[j]
                    seen_labels_for_image = {current_label}
                    if len(top_predictions) == 0:
                        print("Empty predictions for image", original_index)
                        continue

                    # Each prediction is unpacked as a 3-tuple below and its
                    # probability is read at index 2, so three items are required.
                    if len(top_predictions[0]) < 3:
                        print("Malformed predictions for image", original_index)
                        continue

                    if top_predictions[0][2] > min_confidence:
                        filtered_predictions = top_predictions

                        # Skip images whose top prediction is not the true label.
                        if filtered_predictions[0][1] != current_label:
                            continue

                        for _, pred_label, pred_prob in filtered_predictions:
                            if pred_label not in labels:
                                raise ValueError(
                                    f"Prediction label '{pred_label}' not in graph labels."
                                )
                            # Add to seen labels set for this image
                            seen_labels_for_image.add(pred_label)

                            if current_label != pred_label:
                                edge_data = builder.update_graph(
                                    graph, current_label, pred_label, pred_prob, original_index
                                )
                                # Only append edge_data if it's not None (not a self-loop)
                                if edge_data is not None:
                                    edges_data.append(edge_data)
                                    added_labels.append(pred_label)

                        # Now add infinity edges for all labels not seen in THIS image
                        if explanation_method == ExplanationMethod.DISSIMILARITY.value:
                            for label in labels:
                                if label not in seen_labels_for_image:
                                    builder.add_infinity_edges(
                                        graph, added_labels, label, current_label
                                    )

                batch_images = []
                true_labels = []
                original_dataset_positions = []

        edges_df = None
        if save_connections:
            edges_df = pd.DataFrame(edges_data)

        heap_processor = HeapProcessor(graph, explanation_method, labels)
        clustering = HierarchicalClusteringBuilder(heap_processor, labels)

        z_builder = ZBuilder()
        z = z_builder.create_z_matrix_from_tree(clustering, labels)

        dendrogram_object = Dendrogram(z)
        # The return value is not used here; the call is kept from the original.
        dendrogram_object.build_tree_hierarchy(z, labels)

        return dendrogram_object, edges_df
    except Exception as e:
        print(f"Error while preprocessing model: {str(e)}")
        # Re-raise: returning None here made the caller's tuple unpacking fail
        # with an unrelated TypeError that hid the real cause.
        raise
|
BETTER_NMA/plot.py
ADDED
@@ -0,0 +1,131 @@
|
|
1
|
+
from re import sub
|
2
|
+
import matplotlib.pyplot as plt
|
3
|
+
import scipy.cluster.hierarchy as sch
|
4
|
+
import numpy as np
|
5
|
+
from .utilss.classes.dendrogram import Dendrogram
|
6
|
+
|
7
|
+
|
8
|
+
def extract_sub_dendrogram(Z_full, labels, selected_labels):
    """Extract a sub-dendrogram from the full Z matrix for only the selected labels.

    Parameters:
    - Z_full: Iterable of 4-item linkage rows (left, right, height, size); row i
      is treated as creating node ``len(labels) + i``
    - labels: All class names; a name's position is its leaf index in Z_full
    - selected_labels: Names to keep; leaf k of the result is selected_labels[k]

    Returns:
    - (Z_sub, selected_labels): Z_sub is a NumPy array with one row per merge that
      joins two sub-trees both containing selected leaves, or an empty (0, 4)
      array when there is no such merge

    Raises:
    - ValueError: if a selected label is not present in labels
    """
    label_to_idx = {name: idx for idx, name in enumerate(labels)}

    for label in selected_labels:
        if label not in label_to_idx:
            raise ValueError(
                f"Label '{label}' not found in original class names")

    selected_indices = [label_to_idx[label] for label in selected_labels]
    n_selected = len(selected_indices)
    n_labels = len(labels)

    # Original node id -> node id in the sub-dendrogram.
    new_positions = {original: new for new, original in enumerate(selected_indices)}
    # Sub-dendrogram node id -> number of selected leaves below it.
    cluster_size = {leaf: 1 for leaf in range(n_selected)}
    # Original node ids that currently stand for a sub-tree with selected leaves.
    active_nodes = set(selected_indices)
    next_id = n_selected
    rows = []

    for i, (left, right, height, _) in enumerate(Z_full):
        left, right = int(left), int(right)
        merged_id = n_labels + i
        left_active = left in active_nodes
        right_active = right in active_nodes

        if left_active and right_active:
            # A real merge in the sub-dendrogram: emit a row and allocate a new id.
            new_left = new_positions[left]
            new_right = new_positions[right]
            new_size = cluster_size.get(new_left, 1) + cluster_size.get(new_right, 1)

            if new_left > new_right:
                new_left, new_right = new_right, new_left

            rows.append([new_left, new_right, height, new_size])

            active_nodes.discard(left)
            active_nodes.discard(right)
            active_nodes.add(merged_id)
            new_positions[merged_id] = next_id
            cluster_size[next_id] = new_size
            next_id += 1
        elif left_active or right_active:
            # Only one side carries selected leaves: the merged node simply
            # inherits that side's position, no row is emitted.
            survivor = left if left_active else right
            active_nodes.discard(survivor)
            active_nodes.add(merged_id)
            new_positions[merged_id] = new_positions[survivor]

    if not rows:
        return np.empty((0, 4)), selected_labels

    Z_sub = np.array(rows)
    # Safety clamp: row k may reference node ids up to n_selected - 1 + k.
    # For a well-formed linkage this is expected to leave the rows unchanged.
    limits = np.arange(n_selected - 1, n_selected - 1 + len(Z_sub))
    Z_sub[:, 0] = np.minimum(Z_sub[:, 0], limits)
    Z_sub[:, 1] = np.minimum(Z_sub[:, 1], limits)

    return Z_sub, selected_labels
|
74
|
+
|
75
|
+
|
76
|
+
def plot_sub_dendrogram(Z, labels, selected_labels, title, figsize):
    """Plot the sub-dendrogram.

    Parameters:
    - Z: Full linkage matrix
    - labels: All class names, in the leaf order used by Z
    - selected_labels: Names to include in the plot
    - title: Figure title
    - figsize: Figure size passed to plt.figure

    Returns: None (displays the figure with plt.show()).

    Raises:
    - ValueError: if the extracted sub-dendrogram has no rows, or (from
      extract_sub_dendrogram) if a selected label is unknown
    """
    Z_sub, selected_labels = extract_sub_dendrogram(Z, labels, selected_labels)

    if len(Z_sub) == 0:
        raise ValueError(
            "No clustering relationships found among selected labels.")

    plt.figure(figsize=figsize)
    sch.dendrogram(
        Z_sub,
        labels=selected_labels,
        leaf_rotation=0,
        leaf_font_size=10,
        orientation='right',
    )
    plt.title(title)
    plt.xlabel("Distance")
    plt.ylabel("Elements")
    plt.tight_layout()
    plt.show()
|
97
|
+
|
98
|
+
|
99
|
+
def plot(nma_instance, sub_labels, title, figsize, **kwargs):
    """Print and plot the dendrogram of an NMA instance.

    Parameters:
    - nma_instance: Object with a ``dendrogram_object`` (exposing ``Z`` and
      ``get_sub_dendrogram_formatted``) and, normally, a ``labels`` attribute
    - sub_labels: Labels to show; None means all of nma_instance.labels
    - title: Figure title
    - figsize: Figure size
    - **kwargs: Accepted but not used by this function

    Returns: None (prints the filtered structure and the Z matrix; shows the plot).

    Raises:
    - ValueError: if the dendrogram object has no Z matrix
    """
    dendrogram = nma_instance.dendrogram_object
    if dendrogram.Z is None:
        raise ValueError("No linkage matrix (z) found in NMA instance")

    chosen_labels = nma_instance.labels if sub_labels is None else sub_labels

    formatted = dendrogram.get_sub_dendrogram_formatted(chosen_labels)
    print("Filtered dendrogram structure:")
    print(formatted)

    # The figure is only drawn when the instance carries its label list.
    if hasattr(nma_instance, 'labels'):
        plot_sub_dendrogram(dendrogram.Z, nma_instance.labels,
                            chosen_labels, title, figsize)

    print(dendrogram.Z)
|
116
|
+
|
117
|
+
# plt.figure(figsize=(20, 15))
|
118
|
+
# sch.dendrogram(
|
119
|
+
# nma_instance.z,
|
120
|
+
# leaf_rotation=0,
|
121
|
+
# leaf_font_size=10,
|
122
|
+
# orientation='right',
|
123
|
+
# color_threshold=85,
|
124
|
+
# **kwargs
|
125
|
+
# )
|
126
|
+
# plt.title('NMA Hierarchical Clustering Dendrogram')
|
127
|
+
# plt.xlabel('Elements')
|
128
|
+
# plt.ylabel('Distance')
|
129
|
+
# plt.tight_layout()
|
130
|
+
# plt.show()
|
131
|
+
|