BETTER-NMA 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. BETTER_NMA/__init__.py +15 -0
  2. BETTER_NMA/adversarial_score.py +19 -0
  3. BETTER_NMA/change_cluster_name.py +0 -0
  4. BETTER_NMA/detect_attack.py +108 -0
  5. BETTER_NMA/find_lca.py +21 -0
  6. BETTER_NMA/main.py +285 -0
  7. BETTER_NMA/nma_creator.py +108 -0
  8. BETTER_NMA/plot.py +131 -0
  9. BETTER_NMA/query_image.py +22 -0
  10. BETTER_NMA/train_adversarial_detector.py +21 -0
  11. BETTER_NMA/utilss/__init__.py +0 -0
  12. BETTER_NMA/utilss/classes/__init__.py +0 -0
  13. BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
  14. BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
  15. BETTER_NMA/utilss/classes/dendrogram.py +131 -0
  16. BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
  17. BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
  18. BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
  19. BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
  20. BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
  21. BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
  22. BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
  23. BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
  24. BETTER_NMA/utilss/classes/score_calculator.py +165 -0
  25. BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
  26. BETTER_NMA/utilss/enums/__init__.py +0 -0
  27. BETTER_NMA/utilss/enums/explanation_method.py +6 -0
  28. BETTER_NMA/utilss/enums/heap_types.py +5 -0
  29. BETTER_NMA/utilss/models_utils.py +18 -0
  30. BETTER_NMA/utilss/photos_uitls.py +72 -0
  31. BETTER_NMA/utilss/photos_utils.py +104 -0
  32. BETTER_NMA/utilss/verbal_explanation.py +15 -0
  33. BETTER_NMA/utilss/wordnet_utils.py +177 -0
  34. BETTER_NMA/white_box_testing.py +101 -0
  35. BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
  36. BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
  37. BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
  38. BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .main import NMA
2
+ from .white_box_testing import (
3
+ visualize_problematic_images,
4
+ analyze_white_box_results,
5
+ get_white_box_analysis
6
+ )
7
+ from .white_box_testing import save_white_box_results, load_white_box_results
8
+
9
# Public names re-exported by the package.
__all__ = [
    "NMA",
    "visualize_problematic_images",
    "analyze_white_box_results",
    "get_white_box_analysis",
    "save_white_box_results",
    "load_white_box_results",
]
@@ -0,0 +1,19 @@
1
+ from .utilss.classes.score_calculator import ScoreCalculator
2
+ from .utilss.photos_utils import preprocess_loaded_image
3
+
4
def get_adversarial_score(image, model, full_z, class_names, top_k=5):
    """
    Compute the adversarial score of a single image.

    Parameters:
    - image: Image to score; it is run through preprocess_loaded_image (together with the model) before prediction
    - model: Model whose predict() output is scored
    - full_z: The full Z matrix from hierarchical clustering, handed to ScoreCalculator
    - class_names: List of class names corresponding to model outputs
    - top_k: Forwarded to ScoreCalculator.calculate_adversarial_score (default: 5)

    Returns:
    - The value produced by ScoreCalculator.calculate_adversarial_score for the first prediction row
    """
    calculator = ScoreCalculator(full_z, class_names)
    # Only the model-ready input is needed here; the second return value is ignored.
    model_input, _ = preprocess_loaded_image(model, image)
    first_row = model.predict(model_input, verbose=0)[0]
    return calculator.calculate_adversarial_score(first_row, top_k=top_k)
File without changes
@@ -0,0 +1,108 @@
1
+ from .utilss.classes.score_calculator import ScoreCalculator
2
+ from .utilss.photos_utils import preprocess_loaded_image
3
+ from .utilss.models_utils import get_top_k_predictions
4
+ import matplotlib.pyplot as plt
5
+
6
def plot_detection_result(detection_result, figsize=(12, 8), top_k=5):
    """
    Show a detection result: the image on the left, a bar chart of the top predictions
    plus a coloured detection box on the right.

    Parameters:
    - detection_result: Dictionary with keys 'image', 'predictions', 'result', 'probability'
    - figsize: Tuple for figure size
    - top_k: Number of top predictions to display
    """
    _, (image_ax, bars_ax) = plt.subplots(1, 2, figsize=figsize)

    # Left panel: the input image without axes.
    image_ax.imshow(detection_result['image'])
    image_ax.axis('off')
    image_ax.set_title('Input Image', fontsize=14, fontweight='bold')

    shown = detection_result['predictions'][:top_k]

    # Entries are either (class_name, probability) tuples or dictionaries
    # with 'class' and 'probability' keys; pick the matching lookup keys once.
    if shown and isinstance(shown[0], tuple):
        name_key, prob_key = 0, 1
    else:
        name_key, prob_key = 'class', 'probability'
    classes = [entry[name_key] for entry in shown]
    probabilities = [entry[prob_key] for entry in shown]

    # Right panel: one horizontal bar per prediction.
    rows = range(len(classes))
    bars_ax.barh(rows, probabilities, color='skyblue')
    bars_ax.set_yticks(rows)
    bars_ax.set_yticklabels(classes, fontsize=10)
    bars_ax.set_xlabel('Probability', fontsize=12)
    bars_ax.set_title(f'Top {top_k} Predictions', fontsize=14, fontweight='bold')
    bars_ax.set_xlim(0, 1)

    # Print each probability just to the right of its bar.
    for row, prob in enumerate(probabilities):
        bars_ax.text(prob + 0.01, row, f'{prob:.3f}', va='center', fontsize=9)

    detection_status = detection_result['result']
    detection_prob = detection_result['probability']

    # Red box for an adversarial verdict, green for anything else.
    if detection_status == 'Adversarial':
        status_color = 'red'
    else:
        status_color = 'green'

    box_style = {'boxstyle': 'round', 'facecolor': status_color, 'alpha': 0.3}
    summary = f'Detection: {detection_status}\nConfidence: {detection_prob:.3f}'
    bars_ax.text(0.02, 0.98, summary, transform=bars_ax.transAxes, fontsize=12,
                 verticalalignment='top', bbox=box_style, fontweight='bold')

    plt.tight_layout()
    plt.show()
63
+
64
def detect_adversarial_image(model, image, detector, Z_full, labels, plot_result=False):
    """
    Classify an image as adversarial or clean with a trained detector.

    Parameters:
    - model: The original model being attacked
    - image: The input image to check
    - detector: The trained detector; must provide predict() and predict_proba()
    - Z_full: Hierarchical clustering data handed to ScoreCalculator
    - labels: List of class names
    - plot_result: Whether to plot the detection result (default: False)

    Returns:
    - Dictionary with 'image', 'predictions', 'result' ('Adversarial' or 'Clean')
      and 'probability' (detector probability of the adversarial class)
    """
    model_input, pil_image = preprocess_loaded_image(model, image)
    calculator = ScoreCalculator(Z_full, labels)

    # Score the first prediction row of the original model.
    raw_predictions = model.predict(model_input, verbose=0)
    score = calculator.calculate_adversarial_score(raw_predictions[0])

    # The detector takes a 2D feature array; the score is its only feature.
    features = [[score]]
    predicted_label = detector.predict(features)[0]  # 1 is treated as adversarial
    adversarial_probability = detector.predict_proba(features)[0][1]
    if predicted_label == 1:
        verdict = 'Adversarial'
    else:
        verdict = 'Clean'

    top_predictions = get_top_k_predictions(model, model_input, labels)

    outcome = {
        "image": pil_image,
        "predictions": top_predictions,
        "result": verdict,
        "probability": adversarial_probability,
    }

    if plot_result:
        plot_detection_result(outcome)

    return outcome
BETTER_NMA/find_lca.py ADDED
@@ -0,0 +1,21 @@
1
+ from .utilss.classes.score_calculator import ScoreCalculator
2
+
3
def get_lca(label1, label2, dendrogram, class_names):
    """
    Find the Lowest Common Ancestor (LCA) of two classes in a hierarchical clustering dendrogram.

    Parameters:
    - label1: Name of the first class
    - label2: Name of the second class
    - dendrogram: Dendrogram object; its Z attribute is handed to ScoreCalculator and
      its get_node_name() turns the LCA index into a name
    - class_names: Sequence of class names corresponding to model outputs

    Returns:
    - lca: The value dendrogram.get_node_name() returns for the LCA index

    Raises:
    - ValueError: If label1 or label2 is not one of class_names
    """
    # Resolve both labels first so an unknown label fails with a message that names it.
    # list() lets non-list sequences (tuples, arrays) be searched with .index().
    known_names = list(class_names)
    indices = []
    for label in (label1, label2):
        try:
            indices.append(known_names.index(label))
        except ValueError:
            raise ValueError(f"Label '{label}' not found in class names") from None

    score_calculator = ScoreCalculator(dendrogram.Z, class_names)
    # count_ancestors_to_lca returns a pair; only the LCA index is used here.
    _, lca_idx = score_calculator.count_ancestors_to_lca(indices[0], indices[1])
    return dendrogram.get_node_name(lca_idx)
BETTER_NMA/main.py ADDED
@@ -0,0 +1,285 @@
1
+ import tempfile
2
+ import os
3
+ from .nma_creator import preprocessing
4
+ from .plot import plot, plot_sub_dendrogram
5
+ from .train_adversarial_detector import create_logistic_regression_detector
6
+ from .utilss.classes.whitebox_testing import WhiteBoxTesting
7
+ from .detect_attack import detect_adversarial_image
8
+ from .query_image import query_image
9
+ from .utilss.verbal_explanation import get_verbal_explanation
10
+ from .white_box_testing import analyze_white_box_results, get_white_box_analysis
11
+ from .adversarial_score import get_adversarial_score
12
+ from .find_lca import get_lca
13
+
14
class NMA:
    """
    Main entry point of the package.

    Builds the dendrogram (and optionally the edges dataframe) once at construction
    time and exposes plotting, white-box testing, adversarial detection and image
    querying on top of it.
    """

    def __init__(self, x_train, y_train, labels, model, explanation_method, top_k=4, min_confidence=0.8, infinity=None, threshold=1e-6, save_connections=False, batch_size=32):
        """
        Initializes the NMA object with training data, model, and parameters.

        Inputs:
        - x_train: Training images (e.g., NumPy array).
        - y_train: Training labels (e.g., list or array).
        - labels: List of class labels (e.g., ['cat', 'dog']).
        - model: Pre-trained model for predictions.
        - explanation_method: Method for generating explanations.
        - top_k: Number of top predictions to consider (default: 4).
        - min_confidence: Minimum confidence threshold (default: 0.8).
        - infinity: Value for infinity in calculations, usually the labels count (default: None).
        - threshold: Threshold for clustering, depends on the model (default: 1e-6).
        - save_connections: Whether to save edges dataframe, use True for white box testing (default: False).
        - batch_size: Batch size for processing (default: 32).

        Outputs: None (initializes the object).

        Raises:
        - RuntimeError: If preprocessing does not return a (dendrogram, edges dataframe) pair.

        Explanation: Sets up the dendrogram (visual explanation) and edges dataframe using preprocessing.
        """
        self.model = model
        self.explanation_method = explanation_method
        self.top_k = top_k
        self.labels = labels
        self.min_confidence = min_confidence
        self.infinity = infinity
        self.threshold = threshold
        self.save_connections = save_connections
        self.batch_size = batch_size
        self.detector = None  # set by train_adversarial_detector

        built = preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size)
        # preprocessing may return None when it fails; report that directly
        # instead of failing on the tuple unpacking below.
        if built is None:
            raise RuntimeError("NMA preprocessing failed; no dendrogram was built")
        self.dendrogram_object, self.edges_df = built
        print("NMA initialized")

    ## internal helpers: ##

    def _model_name(self):
        """Return the model's name attribute, or "model" when it has none."""
        return getattr(self.model, 'name', "model")

    def _require_edges_df(self):
        """Raise ValueError when the edges dataframe was not saved at construction."""
        if self.edges_df is None:
            raise ValueError("White box testing requires edges_df. Initialize NMA with save_connections=True")

    ## plot functions: ##

    def plot(self, sub_labels=None, title="Sub Dendrogram", figsize=(12, 8), **kwargs):
        """
        Plots the dendrogram.

        Inputs:
        - sub_labels (optional, if not defined, uses all labels): List of labels to highlight (e.g., ['cat', 'dog']).
        - title (optional): Plot title (default: "Sub Dendrogram").
        - figsize (optional): Figure size (default: (12, 8)).
        - **kwargs: Additional arguments for plotting.

        Outputs: None (displays plot).

        Explanation: Visualizes the full dendrogram or highlights sub-labels. includes json representation of the dendrogram.
        """
        # Inside a method the bare name `plot` is the module-level function, not this method.
        plot(self, sub_labels, title=title, figsize=figsize, **kwargs)

    def plot_sub_dendrogram(self, sub_labels, title="Sub Dendrogram", figsize=(12, 8)):
        """
        Plots a sub-dendrogram for specific labels.

        Inputs:
        - sub_labels: List of labels to include (e.g., ['apple', 'banana']).
        - title (optional): Plot title (default: "Sub Dendrogram").
        - figsize (optional): Figure size (default: (12, 8)).

        Outputs: None (displays plot).

        Explanation: Renders a subset of the dendrogram based on provided labels.
        """
        plot_sub_dendrogram(self.dendrogram_object.Z, self.labels, sub_labels, title=title, figsize=figsize)

    ## white box testing functions: ##

    def white_box_testing(self, source_labels, target_labels, analyze_results=False, x_train=None, encode_images=True):
        """
        Performs white-box testing to find problematic images.

        Inputs:
        - source_labels: List of source labels (e.g., ['cat']).
        - target_labels: List of target labels (e.g., ['dog']).
        - analyze_results (optional): Whether to analyze results (default: False).
        - x_train (optional): Training images for analysis.
        - encode_images (optional): Whether to encode images (default: True).

        Outputs: Dictionary of problematic images or analyzed results.

        Raises:
        - ValueError: If the edges dataframe was not saved (save_connections=False).

        Explanation: Finds images that could be misclassified using edges dataframe.
        """
        self._require_edges_df()

        whitebox = WhiteBoxTesting(self._model_name())
        problematic_imgs_dict = whitebox.find_problematic_images(
            source_labels, target_labels, self.edges_df, self.explanation_method)

        if analyze_results:
            return analyze_white_box_results(problematic_imgs_dict, x_train, encode_images)

        return problematic_imgs_dict

    def get_white_box_analysis(self, source_labels, target_labels, x_train=None):
        """
        Runs the module-level get_white_box_analysis on this instance's edges dataframe.

        Inputs:
        - source_labels: List of source labels.
        - target_labels: List of target labels.
        - x_train (optional): Training images forwarded to the analysis.

        Outputs: Whatever the module-level get_white_box_analysis returns.

        Raises:
        - ValueError: If the edges dataframe was not saved (save_connections=False).

        Explanation: Writes the edges dataframe to a temporary CSV file, passes its path
        to the analysis and always removes the file afterwards.
        """
        self._require_edges_df()

        # delete=False: only the path is needed; the handle is closed before the
        # CSV is written and the file is removed in the finally block.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as handle:
            temp_path = handle.name

        try:
            # Writing inside the try ensures the file is removed even if to_csv fails.
            self.edges_df.to_csv(temp_path, index=False)
            # NOTE(review): explanation_method is what gets passed as dataset_str — confirm that is intended.
            return get_white_box_analysis(
                edges_df_path=temp_path,
                model_filename=self._model_name(),
                dataset_str=str(self.explanation_method),
                source_labels=source_labels,
                target_labels=target_labels,
                x_train=x_train
            )
        finally:
            os.unlink(temp_path)

    ## adversarial detection functions: ##

    def train_adversarial_detector(self, authentic_images, attacked_images):
        """
        Trains an adversarial detector.

        Inputs:
        - authentic_images: Array of clean images.
        - attacked_images: Array of adversarial images.

        Outputs: Trained detector model (also stored on self.detector).

        Raises:
        - ValueError: If no dendrogram is available.

        Explanation: Trains a logistic regression detector using dendrogram data.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data for adversarial detection")

        self.detector = create_logistic_regression_detector(
            self.dendrogram_object.Z,
            self.model,
            authentic_images,
            attacked_images,
            self.labels
        )
        return self.detector

    def detect_attack(self, image, plot_result=False):
        """
        Detects if an image is adversarial.

        Inputs:
        - image: Image to analyze.
        - plot_result (optional): Whether to plot results (default: False).

        Outputs: The result of detect_adversarial_image for this image.

        Raises:
        - ValueError: If train_adversarial_detector has not been called yet.

        Explanation: Uses trained detector to check for adversarial attacks.
        """
        if self.detector is None:
            raise ValueError("Adversarial detector not trained. Call train_adversarial_detector first.")

        return detect_adversarial_image(
            self.model,
            image,
            self.detector,
            self.dendrogram_object.Z,
            self.labels,
            plot_result=plot_result
        )

    def find_lca(self, label1, label2):
        """
        Finds the lowest common ancestor of two labels.

        Inputs:
        - label1: First label (e.g., 'cat').
        - label2: Second label (e.g., 'dog').

        Outputs: LCA cluster or label.

        Explanation: Determines the LCA in the dendrogram hierarchy.
        """
        return get_lca(label1, label2, self.dendrogram_object, self.labels)

    def adversarial_score(self, image, top_k=5):
        """
        Computes adversarial score for an image.

        Inputs:
        - image: Image to score.
        - top_k (optional): Number of top predictions (default: 5).

        Outputs: Adversarial score.

        Explanation: Calculates score based on predictions and dendrogram.
        """
        return get_adversarial_score(image, self.model, self.dendrogram_object.Z, self.labels, top_k=top_k)

    ## query and explanation functions: ##

    def query_image(self, image, top_k=5):
        """
        Queries the model for predictions and explanations.

        Inputs:
        - image: Image to query.
        - top_k (optional): Number of top predictions (default: 5).

        Outputs: The result of the module-level query_image; verbal_explanation
        unpacks it as a (predictions, explanation) pair.

        Raises:
        - ValueError: If the dendrogram, the labels or the model are missing.

        Explanation: Predicts and explains using dendrogram and verbal explanation of the dendrogram.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data to query images")

        # len() rather than truthiness: labels may be an array, whose truth value is ambiguous.
        if self.labels is None or len(self.labels) == 0:
            raise ValueError("NMA must be initialized with labels to query images")

        if self.model is None:
            raise ValueError("NMA must be initialized with a model to query images")

        return query_image(image, self.model, self.labels, self.dendrogram_object, top_k=top_k)

    def verbal_explanation(self, image):
        """
        Generates a verbal explanation for an image.

        Inputs:
        - image: Image to explain.

        Outputs: Verbal explanation, or None when query_image returns None.

        Raises:
        - ValueError: Same conditions as query_image.

        Explanation: Calls query_image and returns the explanation.
        """
        # query_image performs the dendrogram / labels / model validation.
        result = self.query_image(image)
        if result is None:
            return None
        _predictions, explanation = result
        return explanation

    def change_cluster_name(self, cluster_id, new_name):
        """
        Renames a cluster in the dendrogram.

        Inputs:
        - cluster_id: ID of the cluster to rename.
        - new_name: New name for the cluster.

        Outputs: None (prints success or raises error).

        Raises:
        - ValueError: If no dendrogram is available or the rename is rejected.

        Explanation: Updates the cluster name if valid.
        """
        if self.dendrogram_object is None:
            raise ValueError("NMA must be initialized with dendrogram data to change cluster names")

        if not self.dendrogram_object.rename_cluster(cluster_id, new_name):
            raise ValueError(f"Failed to rename cluster: {cluster_id}")

        print(f"Cluster {cluster_id} renamed to {new_name}")
@@ -0,0 +1,108 @@
1
+ import tensorflow as tf
2
+ import pandas as pd
3
+ from igraph import Graph
4
+ from .utilss.enums.explanation_method import ExplanationMethod
5
+ from .utilss.classes.preprocessing.batch_predictor import BatchPredictor
6
+ from .utilss.classes.preprocessing.heap_processor import HeapProcessor
7
+ from .utilss.classes.preprocessing.graph_builder import GraphBuilder
8
+ from .utilss.classes.preprocessing.hierarchical_clustering_builder import HierarchicalClusteringBuilder
9
+ from .utilss.classes.preprocessing.z_builder import ZBuilder
10
+ from .utilss.classes.dendrogram import Dendrogram
11
+
12
+ def preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size=32):
# Builds the dendrogram for a model: walks x_train in batches, asks BatchPredictor
# for the top predictions of each batch, feeds them to GraphBuilder to update an
# igraph Graph whose vertices are the labels, then runs HeapProcessor ->
# HierarchicalClusteringBuilder -> ZBuilder -> Dendrogram.
# Returns (dendrogram_object, edges_df); edges_df is None unless save_connections is set.
# NOTE(review): this listing shows the code without indentation, so the nesting of
# the inner statements cannot be read off here — check the packaged file.
13
+ try:
14
+ X = x_train
15
+ y = y_train
16
+
17
+ graph = Graph(directed=False)
18
+ graph.add_vertices(labels)
19
+
20
+ edges_data = []
21
+ batch_images = []
22
+ true_labels = []
23
+ original_dataset_positions = []
24
+
25
+ predictor = BatchPredictor(model, batch_size)
26
+ builder = GraphBuilder(explanation_method, infinity)
# NOTE(review): count only gates a commented-out debug print below; it looks like dead code.
27
+ count = 0
28
+
29
+ for i, image in enumerate(X):
30
+ source_label = y[i]
31
+ batch_images.append(image)
32
+ true_labels.append(source_label)
33
+ original_dataset_positions.append(i)
34
+
# Flush when the batch is full or the last image has been reached.
35
+ if len(batch_images) == predictor.batch_size or i == len(X) - 1:
36
+ top_predictions_batch = predictor.get_top_predictions(
37
+ batch_images, labels, top_k, threshold
38
+ )
# NOTE(review): added_labels is reset here, before the per-image loop, so it
# appears to accumulate across the images of a batch — confirm that is intended.
39
+ added_labels = []
40
+ for j, top_predictions in enumerate(top_predictions_batch):
41
+ current_label = true_labels[j]
42
+ original_index = original_dataset_positions[j]
43
+ seen_labels_for_image = {current_label}
44
+ if len(top_predictions) == 0:
45
+ print("Empty predictions for image", original_index)
46
+ continue
47
+
# NOTE(review): this guard accepts entries of length 2, but the next statement
# reads index 2 and the loop below unpacks three fields — the check probably
# should be `< 3`.
48
+ if len(top_predictions[0]) < 2:
49
+ print("Malformed predictions for image", original_index)
50
+ continue
51
+
# Index 2 of a prediction entry is compared with min_confidence; index 1 is
# compared with the true label just below.
52
+ if top_predictions[0][2] > min_confidence:
53
+ filtered_predictions = top_predictions
54
+
# Images whose first prediction is not the true label are skipped.
55
+ if filtered_predictions[0][1] != current_label:
56
+ continue
57
+
58
+ if count < 10:
59
+ # print(filtered_predictions)
60
+ count = count + 1
61
+
62
+ for _, pred_label, pred_prob in filtered_predictions:
63
+ if pred_label not in labels:
64
+ raise ValueError(
65
+ f"Prediction label '{pred_label}' not in graph labels."
66
+ )
67
+ # Add to seen labels set for this image
68
+ seen_labels_for_image.add(pred_label)
69
+
70
+ if current_label != pred_label:
71
+ edge_data = builder.update_graph(
72
+ # graph, current_label, pred_label, pred_prob, i, dataset_class
73
+ graph, current_label, pred_label, pred_prob, original_index
74
+ )
75
+ # Only append edge_data if it's not None (not a self-loop)
76
+ if edge_data is not None:
77
+ edges_data.append(edge_data)
78
+ added_labels.append(pred_label)
79
+
80
+ # Now add infinity edges for all labels not seen in THIS image
81
+ if explanation_method == ExplanationMethod.DISSIMILARITY.value:
82
+ for label in labels:
83
+ # if label != current_label:
84
+ if label not in seen_labels_for_image:
85
+ builder.add_infinity_edges(
86
+ graph, added_labels, label, current_label
87
+ )
88
+
# Start a fresh batch.
89
+ batch_images = []
90
+ true_labels = []
91
+ original_dataset_positions = []
92
+
93
+ edges_df = None
94
+ if save_connections:
95
+ edges_df = pd.DataFrame(edges_data)
96
+
97
+ heap_processor = HeapProcessor(graph, explanation_method, labels)
98
+ clustering = HierarchicalClusteringBuilder(heap_processor, labels)
99
+
100
+ z_builder = ZBuilder()
101
+ z = z_builder.create_z_matrix_from_tree(clustering, labels)
102
+
103
+ dendrogram_object = Dendrogram(z)
# The return value of build_tree_hierarchy is not used in this function.
104
+ dendrogram = dendrogram_object.build_tree_hierarchy(z, labels)
105
+
106
+ return dendrogram_object, edges_df
# NOTE(review): any exception (including the ValueError raised above) is only
# printed; the function then falls off the end and returns None, while
# NMA.__init__ unpacks the result into two names. Re-raising after the print
# would surface the real error.
107
+ except Exception as e:
108
+ print(f"Error while preprocessing model: {str(e)}")
BETTER_NMA/plot.py ADDED
@@ -0,0 +1,131 @@
1
+ from re import sub
2
+ import matplotlib.pyplot as plt
3
+ import scipy.cluster.hierarchy as sch
4
+ import numpy as np
5
+ from .utilss.classes.dendrogram import Dendrogram
6
+
7
+
8
def extract_sub_dendrogram(Z_full, labels, selected_labels):
    """
    Extract a sub-dendrogram from the full Z matrix for only the selected labels.

    Parameters:
    - Z_full: Iterable of 4-item rows (left, right, height, size); row i is taken to
      create node id len(labels) + i
    - labels: All class names; a label's position is its leaf id in Z_full
    - selected_labels: Labels to keep; their order defines the new leaf ids 0..k-1

    Returns:
    - (Z_sub, selected_labels): Z_sub is a NumPy array with one
      [left, right, height, size] row per merge that joins two selected subtrees,
      or an empty (0, 4) array when there is no such merge

    Raises:
    - ValueError: If a selected label is not in labels
    """
    label_to_index = {name: i for i, name in enumerate(labels)}

    for label in selected_labels:
        if label not in label_to_index:
            raise ValueError(
                f"Label '{label}' not found in original class names")

    selected_indices = [label_to_index[label] for label in selected_labels]
    leaf_count = len(selected_indices)
    label_count = len(labels)

    # Original node id -> id in the sub-dendrogram.
    new_positions = {original: new for new, original in enumerate(selected_indices)}
    cluster_size = {i: 1 for i in range(leaf_count)}
    # Only membership is ever tested, so a set replaces the list scans.
    active_nodes = set(selected_indices)
    next_id = leaf_count
    merges = []

    for i, (left, right, height, _) in enumerate(Z_full):
        left, right = int(left), int(right)
        merged_id = label_count + i
        left_active = left in active_nodes
        right_active = right in active_nodes

        if left_active and right_active:
            # Both children carry selected leaves: this merge is kept.
            new_left, new_right = sorted((new_positions[left], new_positions[right]))
            new_size = cluster_size.get(new_left, 1) + cluster_size.get(new_right, 1)
            merges.append([new_left, new_right, height, new_size])

            active_nodes -= {left, right}
            active_nodes.add(merged_id)
            new_positions[merged_id] = next_id
            cluster_size[next_id] = new_size
            next_id += 1
        elif left_active or right_active:
            # Only one child carries selected leaves: the parent stands in for it.
            survivor = left if left_active else right
            active_nodes.discard(survivor)
            active_nodes.add(merged_id)
            new_positions[merged_id] = new_positions[survivor]

    if not merges:
        return np.empty((0, 4)), selected_labels

    Z_sub = np.array(merges)
    # Row k may only reference ids up to leaf_count - 1 + k; clamp anything larger.
    limits = np.arange(leaf_count - 1, leaf_count - 1 + len(Z_sub))
    Z_sub[:, :2] = np.minimum(Z_sub[:, :2], limits[:, None])

    return Z_sub, selected_labels
74
+
75
+
76
def plot_sub_dendrogram(Z, labels, selected_labels, title, figsize):
    """
    Plot the sub-dendrogram.

    Parameters:
    - Z: Full linkage matrix
    - labels: All class names belonging to Z
    - selected_labels: Labels to include in the plot
    - title: Plot title
    - figsize: Figure size passed to plt.figure

    Raises:
    - ValueError: If a selected label is unknown, or if no merge joins any two selected labels
    """
    Z_sub, selected_labels = extract_sub_dendrogram(Z, labels, selected_labels)
    if len(Z_sub) == 0:
        raise ValueError(
            "No clustering relationships found among selected labels.")

    plt.figure(figsize=figsize)
    sch.dendrogram(
        Z_sub,
        labels=selected_labels,
        leaf_rotation=0,
        leaf_font_size=10,
        orientation='right',
    )
    plt.title(title)
    # With orientation='right' the distance runs along the x axis.
    plt.xlabel("Distance")
    plt.ylabel("Elements")
    plt.tight_layout()
    plt.show()
97
+
98
+
99
def plot(nma_instance, sub_labels, title, figsize, **kwargs):
    """
    Print the filtered dendrogram structure and plot the matching sub-dendrogram.

    Parameters:
    - nma_instance: Object exposing dendrogram_object (with Z and
      get_sub_dendrogram_formatted) and labels
    - sub_labels: Labels to show; None means all labels of the instance
    - title: Plot title
    - figsize: Figure size
    - **kwargs: Accepted but not used by this function

    Raises:
    - ValueError: If the dendrogram object has no linkage matrix
    """
    dendrogram_object = nma_instance.dendrogram_object
    if dendrogram_object.Z is None:
        raise ValueError("No linkage matrix (z) found in NMA instance")

    if sub_labels is None:
        sub_labels = nma_instance.labels

    formatted = dendrogram_object.get_sub_dendrogram_formatted(sub_labels)
    print("Filtered dendrogram structure:")
    print(formatted)

    if hasattr(nma_instance, 'labels'):
        plot_sub_dendrogram(
            nma_instance.dendrogram_object.Z,
            nma_instance.labels,
            sub_labels,
            title,
            figsize,
        )

    # The raw linkage matrix is printed as well.
    print(nma_instance.dendrogram_object.Z)
131
+