BETTER-NMA 1.0.0 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. BETTER_NMA/__init__.py +15 -0
  2. BETTER_NMA/adversarial_score.py +19 -0
  3. BETTER_NMA/change_cluster_name.py +0 -0
  4. BETTER_NMA/detect_attack.py +108 -0
  5. BETTER_NMA/find_lca.py +21 -0
  6. BETTER_NMA/main.py +285 -0
  7. BETTER_NMA/nma_creator.py +108 -0
  8. BETTER_NMA/plot.py +131 -0
  9. BETTER_NMA/query_image.py +22 -0
  10. BETTER_NMA/train_adversarial_detector.py +21 -0
  11. BETTER_NMA/utilss/__init__.py +0 -0
  12. BETTER_NMA/utilss/classes/__init__.py +0 -0
  13. BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
  14. BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
  15. BETTER_NMA/utilss/classes/dendrogram.py +131 -0
  16. BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
  17. BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
  18. BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
  19. BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
  20. BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
  21. BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
  22. BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
  23. BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
  24. BETTER_NMA/utilss/classes/score_calculator.py +165 -0
  25. BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
  26. BETTER_NMA/utilss/enums/__init__.py +0 -0
  27. BETTER_NMA/utilss/enums/explanation_method.py +6 -0
  28. BETTER_NMA/utilss/enums/heap_types.py +5 -0
  29. BETTER_NMA/utilss/models_utils.py +18 -0
  30. BETTER_NMA/utilss/photos_uitls.py +72 -0
  31. BETTER_NMA/utilss/photos_utils.py +104 -0
  32. BETTER_NMA/utilss/verbal_explanation.py +15 -0
  33. BETTER_NMA/utilss/wordnet_utils.py +177 -0
  34. BETTER_NMA/white_box_testing.py +101 -0
  35. BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
  36. BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
  37. BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
  38. BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/utilss/classes/preprocessing/z_builder.py
@@ -0,0 +1,93 @@
+import numpy as np
+
+class ZBuilder:
+    def create_z_matrix_from_tree(self, clustering_builder, labels):
+        # Deduplicate labels while preserving their original order
+        unique_labels = []
+        seen_labels = set()
+        for label in labels:
+            if label not in seen_labels:
+                unique_labels.append(label)
+                seen_labels.add(label)
+
+        print(f"Total labels: {len(labels)}, Unique labels: {len(unique_labels)}")
+        label_to_idx = {label: i for i, label in enumerate(unique_labels)}
+
+        n = len(unique_labels)
+        z_matrix = np.zeros((n - 1, 4), dtype=np.float64)
+        if not clustering_builder.forest:
+            print("ERROR: No tree found in the forest!")
+            # Keep the return type consistent with the normal path below
+            return z_matrix
+
+        root = clustering_builder.forest[0]
+        processed = {}
+        next_z_idx = n
+        row_idx = 0
+
+        def process_node(node):
+            nonlocal row_idx, next_z_idx
+            if node.node_id in processed:
+                return processed[node.node_id]
+
+            # Leaf node: map it to its label index
+            if len(node) == 0:
+                if hasattr(node, "node_name") and node.node_name in label_to_idx:
+                    idx = label_to_idx[node.node_name]
+                    processed[node.node_id] = idx
+                    return idx
+                return None
+
+            # Internal node: process both children before writing a merge row
+            left_idx = None
+            right_idx = None
+            if len(node) >= 1:
+                left_idx = process_node(node[0])
+            if len(node) >= 2:
+                right_idx = process_node(node[1])
+            if left_idx is None or right_idx is None:
+                return None
+            if left_idx > right_idx:
+                left_idx, right_idx = right_idx, left_idx
+
+            left_count = 1 if left_idx < n else z_matrix[left_idx - n][3]
+            right_count = 1 if right_idx < n else z_matrix[right_idx - n][3]
+
+            if row_idx < n - 1:
+                z_matrix[row_idx] = [
+                    left_idx,
+                    right_idx,
+                    node.weight,
+                    left_count + right_count,
+                ]
+                this_idx = next_z_idx
+                processed[node.node_id] = this_idx
+                next_z_idx += 1
+                row_idx += 1
+                return this_idx
+
+            return None
+
+        process_node(root)
+        if row_idx < n - 1:
+            print(f"WARNING: Only filled {row_idx} of {n - 1} rows in Z matrix!")
+            # Trim the matrix to the actual number of rows we filled
+            z_matrix = z_matrix[:row_idx]
+
+        # Sanity-check cluster sizes: each row's count should equal the sum of its children
+        for i in range(z_matrix.shape[0]):
+            left_idx = int(z_matrix[i, 0])
+            right_idx = int(z_matrix[i, 1])
+            left_size = 1 if left_idx < n else z_matrix[left_idx - n, 3]
+            right_size = 1 if right_idx < n else z_matrix[right_idx - n, 3]
+            expected_size = left_size + right_size
+            if z_matrix[i, 3] != expected_size:
+                print(f"Fixing cluster size at row {i}: {z_matrix[i, 3]} → {expected_size}")
+                z_matrix[i, 3] = expected_size
+
+            max_size = z_matrix.shape[0] + 1
+            if z_matrix[i, 3] > max_size:
+                print(f"Capping excessive cluster size at row {i}: {z_matrix[i, 3]} → {max_size}")
+                z_matrix[i, 3] = max_size
+
+        return np.asarray(z_matrix, dtype=np.float64)
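The rows written above use the [left_id, right_id, merge_height, cluster_size] layout of a SciPy linkage matrix. A minimal sketch of that layout with a hand-built matrix (the package's own clustering_builder is not needed for this; the four class names are hypothetical):

    import numpy as np
    from scipy.cluster.hierarchy import dendrogram, is_valid_linkage

    # Hand-built example in the same row layout ZBuilder emits:
    # [left_id, right_id, merge_height, cluster_size] over 4 leaves (ids 0-3).
    Z = np.array([
        [0.0, 1.0, 0.5, 2.0],   # merge leaves 0 and 1 -> cluster 4
        [2.0, 3.0, 0.7, 2.0],   # merge leaves 2 and 3 -> cluster 5
        [4.0, 5.0, 1.2, 4.0],   # merge clusters 4 and 5 -> cluster 6 (the root)
    ], dtype=np.float64)

    print(is_valid_linkage(Z))  # True: SciPy accepts this shape and these counts
    dendrogram(Z, labels=["cat", "dog", "car", "truck"], no_plot=True)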
BETTER_NMA/utilss/classes/score_calculator.py
@@ -0,0 +1,165 @@
+import numpy as np
+import pickle
+
+class ScoreCalculator:
+    def __init__(self, Z_full, class_names):
+        # Z_full may be a pickle file path (string) or an already-built numpy array
+        if isinstance(Z_full, str):
+            try:
+                with open(Z_full, 'rb') as file:
+                    self.Z_full = pickle.load(file)
+            except (FileNotFoundError, pickle.UnpicklingError) as e:
+                print(f"Error loading Z file: {e}")
+                self.Z_full = None
+        elif isinstance(Z_full, np.ndarray):
+            self.Z_full = Z_full
+        else:
+            print(f"Warning: Z_full type {type(Z_full)} not supported. Expected string (file path) or numpy array.")
+            self.Z_full = Z_full
+
+        self.class_names = class_names
+
+    def count_ancestors_to_lca(self, label1, label2):
+        """
+        Count the number of ancestors each node passes through until the two
+        paths meet at their lowest common ancestor (LCA).
+
+        Parameters:
+        - label1, label2: Class labels to compare (names or integer indices)
+
+        Returns:
+        - (total_count, lca): total ancestors traversed to reach the LCA, and the LCA node id
+        """
+        # Convert labels to indices if needed
+        if isinstance(label1, str):
+            idx1 = self.class_names.index(label1)
+        else:
+            idx1 = label1
+
+        if isinstance(label2, str):
+            idx2 = self.class_names.index(label2)
+        else:
+            idx2 = label2
+
+        # Build the hierarchical structure
+        n_samples = len(self.class_names)
+        n_nodes = 2 * n_samples - 1
+
+        # Initialize parent mapping (-1 means no parent)
+        parent = np.zeros(n_nodes, dtype=np.int64) - 1
+
+        # Fill in the structure from Z
+        for i, (left, right, height, _) in enumerate(self.Z_full):
+            left = int(left)
+            right = int(right)
+            node_id = n_samples + i
+
+            parent[left] = node_id
+            parent[right] = node_id
+
+        # Trace path from node1 to the root
+        path1 = []
+        current = idx1
+        while parent[current] != -1:
+            path1.append(parent[current])
+            current = parent[current]
+
+        # Trace path from node2 upward until it meets path1
+        path2 = []
+        current = idx2
+        lca = None
+
+        while parent[current] != -1:
+            current_parent = parent[current]
+            path2.append(current_parent)
+
+            if current_parent in path1:
+                # Found the LCA
+                lca = current_parent
+                break
+
+            current = current_parent
+
+        # If no LCA is found (shouldn't happen in a proper hierarchy), return the maximum distance
+        if lca is None:
+            return n_nodes, None
+
+        # Count steps from node1 to the LCA
+        steps1 = path1.index(lca) + 1
+
+        # Count steps from node2 to the LCA
+        steps2 = path2.index(lca) + 1
+
+        # Total number of ancestors traversed
+        total_count = steps1 + steps2
+
+        return total_count, lca
+
+    def calculate_adversarial_score(self, predictions, top_k=5):
+        """
+        Calculate an adversarial-attack score from the semantic distances between
+        the top-k predicted classes. Works regardless of whether attacks increase
+        or decrease semantic distance.
+
+        Parameters:
+        - predictions: Model output predictions (logits or probabilities)
+        - top_k: Number of top predictions to consider
+
+        Returns:
+        - score: Sum of pairwise ancestor distances between the top-k predictions
+        """
+        # Get top-k predictions
+        if len(predictions.shape) > 1:
+            predictions = predictions[0]  # For batch predictions, take the first item
+
+        top_indices = np.argsort(predictions)[-top_k:][::-1]
+        top_probs = [predictions[i] for i in top_indices]
+        top_labels = [self.class_names[i] for i in top_indices]
+
+        # Calculate ancestors to the LCA for all pairs
+        pairwise_distances = []
+        distance_prob_products = []
+        all_pairs = []
+
+        for i in range(len(top_indices)):
+            for j in range(i + 1, len(top_indices)):
+                idx1, idx2 = top_indices[i], top_indices[j]
+                label1, label2 = top_labels[i], top_labels[j]
+                prob1, prob2 = top_probs[i], top_probs[j]
+
+                # Calculate semantic distance
+                rank_count, _ = self.count_ancestors_to_lca(idx1, idx2)
+                print(f"Rank Count for {label1} and {label2}: {rank_count}")
+
+                # Weight the distance by the product of the pair's probabilities
+                prob_product = prob1 * prob2
+                rank_prob = rank_count * prob_product
+
+                pair_info = {
+                    'label1': label1,
+                    'label2': label2,
+                    'probability1': float(prob1),
+                    'probability2': float(prob2),
+                    'ancestor_distance': rank_count,
+                    'prob_product': prob_product,
+                    'weighted_distance': rank_prob
+                }
+
+                pairwise_distances.append(rank_count)
+                distance_prob_products.append(rank_prob)
+                all_pairs.append(pair_info)
+
+        # The score is the sum of pairwise ancestor distances
+        if pairwise_distances:
+            score = sum(pairwise_distances)
+        else:
+            score = 0
+
+        return score
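A hedged usage sketch of the class above: construct it from a linkage matrix and class-name list, then score a probability vector. The four-class setup and both prediction vectors are hypothetical; the comparison only illustrates that top-k predictions spread across distant branches sum to a larger score.

    import numpy as np

    class_names = ["cat", "dog", "car", "truck"]
    Z = np.array([              # same [left, right, height, size] layout as above
        [0.0, 1.0, 0.5, 2.0],
        [2.0, 3.0, 0.7, 2.0],
        [4.0, 5.0, 1.2, 4.0],
    ])

    calc = ScoreCalculator(Z, class_names)  # ScoreCalculator as defined in the file above

    close_preds = np.array([0.55, 0.40, 0.03, 0.02])  # mass on two nearby classes
    mixed_preds = np.array([0.50, 0.02, 0.45, 0.03])  # mass on two distant classes
    print(calc.calculate_adversarial_score(close_preds, top_k=2))  # 2
    print(calc.calculate_adversarial_score(mixed_preds, top_k=2))  # 4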
BETTER_NMA/utilss/classes/whitebox_testing.py
@@ -0,0 +1,35 @@
+import pandas as pd
+
+class WhiteBoxTesting:
+    def __init__(self, model_name):
+        self.model_name = model_name
+        self.problematic_img_ids = None
+        self.problematic_img_preds = None
+
+    def find_problematic_images(self, source_labels, target_labels, edges_df, explanation_method=None):
+        # Edges that run from a source label to a target label
+        filtered_edges_df = edges_df[
+            (edges_df['source'].isin(source_labels)) &
+            (edges_df['target'].isin(target_labels))
+        ]
+
+        # The same filter with the roles reversed
+        filtered_edges_df_switched = edges_df[
+            (edges_df['source'].isin(target_labels)) &
+            (edges_df['target'].isin(source_labels))
+        ]
+        print(filtered_edges_df_switched.head())
+
+        combined_filtered_edges_df = pd.concat([filtered_edges_df, filtered_edges_df_switched])
+        print("Combined filtered edges dataset:")
+        print(combined_filtered_edges_df)
+
+        unique_ids_list = combined_filtered_edges_df['image_id'].unique().tolist()
+
+        # Map each matched image id to its (source, target, target_probability) edges
+        matched_dict = {
+            image_id: list(zip(group['source'], group['target'], group['target_probability']))
+            for image_id, group in edges_df[edges_df['image_id'].isin(unique_ids_list)].groupby('image_id')
+        }
+        print("Matched dictionary:")
+        print(matched_dict)
+
+        return matched_dict
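A sketch of the dataframe shape find_problematic_images expects, inferred from the columns it reads ('image_id', 'source', 'target', 'target_probability'); the labels and probabilities below are made up.

    import pandas as pd

    edges_df = pd.DataFrame({
        "image_id": [1, 1, 2, 3],
        "source": ["cat", "cat", "dog", "car"],
        "target": ["dog", "fox", "cat", "truck"],
        "target_probability": [0.31, 0.05, 0.27, 0.12],
    })

    tester = WhiteBoxTesting("my_model")
    # Images with prediction edges between the two label groups, in either direction:
    problematic = tester.find_problematic_images(["cat"], ["dog"], edges_df)
    # {1: [('cat', 'dog', 0.31), ('cat', 'fox', 0.05)], 2: [('dog', 'cat', 0.27)]}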
BETTER_NMA/utilss/enums/__init__.py
File without changes
BETTER_NMA/utilss/enums/explanation_method.py
@@ -0,0 +1,6 @@
+from enum import Enum
+
+class ExplanationMethod(Enum):
+    SIMILARITY = "similarity"
+    DISSIMILARITY = "dissimilarity"
+    COUNT = "count_based"
BETTER_NMA/utilss/enums/heap_types.py
@@ -0,0 +1,5 @@
+from enum import Enum
+
+class HeapType(Enum):
+    MINIMUM = "min"
+    MAXIMUM = "max"
BETTER_NMA/utilss/models_utils.py
@@ -0,0 +1,18 @@
+import numpy as np
+
+def get_top_k_predictions(model, image, class_names, top_k=5):
+    # Get predictions from the model
+    predictions = model.predict(image)
+
+    # Flatten the predictions array if it's 2D (e.g., shape (1, num_classes))
+    if len(predictions.shape) == 2:
+        predictions = predictions[0]  # Extract the first (and only) batch
+
+    # Get the indices of the top-k predictions
+    top_indices = np.argsort(predictions)[-top_k:][::-1]
+
+    # Get the top-k probabilities and corresponding labels
+    top_probs = predictions[top_indices]
+    top_labels = [class_names[i] for i in top_indices]
+
+    return list(zip(top_labels, top_probs))
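A self-contained sketch of calling the helper above; DummyModel stands in for any Keras-style object exposing predict(batch) and is not part of the package.

    import numpy as np

    class DummyModel:
        def predict(self, image):
            return np.array([[0.1, 0.6, 0.2, 0.1]])  # shape (1, num_classes)

    class_names = ["cat", "dog", "car", "truck"]
    image = np.zeros((1, 224, 224, 3), dtype=np.float32)  # already-preprocessed batch

    print(get_top_k_predictions(DummyModel(), image, class_names, top_k=2))
    # roughly [('dog', 0.6), ('car', 0.2)]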
BETTER_NMA/utilss/photos_uitls.py
@@ -0,0 +1,72 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
+from tensorflow.keras.applications.vgg16 import preprocess_input as vgg16_preprocess
+from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess
+from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess
+from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
+from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
+from PIL import Image
+import io
+
+def get_preprocess_function(model):
+    print("Determining preprocessing function based on model configuration...")
+    preprocess_map = {
+        "resnet50": resnet50_preprocess,
+        "vgg16": vgg16_preprocess,
+        "inception_v3": inception_v3_preprocess,
+        "mobilenet": mobilenet_preprocess,
+        "efficientnet": efficientnet_preprocess,
+        "xception": xception_preprocess,
+    }
+
+    model_config = model.get_config()
+    if "name" in model_config:
+        model_name = model_config["name"].lower()
+        print(f"Model name: {model_name}")
+        for key in preprocess_map.keys():
+            if key in model_name:
+                print(f"Detected model type: {key}")
+                return preprocess_map[key]
+
+    for layer in model.layers:
+        layer_name = layer.name.lower()
+        print(f"Checking layer: {layer_name}")
+        for model_name in preprocess_map.keys():
+            if model_name in layer_name:
+                print(f"Detected model type: {model_name}")
+                return preprocess_map[model_name]
+
+    print("No supported model type found in the configuration. Falling back to generic normalization.")
+    return lambda x: x / 255.0  # Generic normalization to [0, 1]
+
+
+_cached_preprocess_function = {}
+
+def get_cached_preprocess_function(model):
+    """
+    Get the cached preprocessing function for the given model.
+    If not cached, fetch it and store it in the cache.
+    """
+    global _cached_preprocess_function
+    model_id = id(model)  # Use the model's unique ID as the cache key
+    if model_id not in _cached_preprocess_function:
+        _cached_preprocess_function[model_id] = get_preprocess_function(model)
+    return _cached_preprocess_function[model_id]
+
+def preprocess_loaded_image(model, image):
+    expected_shape = model.input_shape
+    input_height, input_width = expected_shape[1], expected_shape[2]
+    pil_image = Image.open(io.BytesIO(image)).convert("RGB")
+    pil_image = pil_image.resize((input_width, input_height))
+    preprocess_input = get_cached_preprocess_function(model)
+    image_array = preprocess_input(np.array(pil_image))
+    image_preprocessed = np.expand_dims(image_array, axis=0)
+    return image_preprocessed
+
+def preprocess_image(model, image):
+    preprocess_input = get_cached_preprocess_function(model)
+    image_array = preprocess_input(np.array(image))
+    image_preprocessed = np.expand_dims(image_array, axis=0)
+    return image_preprocessed
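A minimal sketch of the bytes-to-batch flow in preprocess_loaded_image, assuming TensorFlow/Keras is installed; weights=None keeps the example offline while preserving the input shape and the "mobilenet" name that the detection logic keys on.

    import io
    import numpy as np
    from PIL import Image
    from tensorflow.keras.applications import MobileNet

    model = MobileNet(weights=None)

    # Encode a dummy RGB image to raw bytes, which is what this module expects.
    buf = io.BytesIO()
    Image.fromarray(np.zeros((300, 400, 3), dtype=np.uint8)).save(buf, format="PNG")

    batch = preprocess_loaded_image(model, buf.getvalue())
    print(batch.shape)  # (1, 224, 224, 3) for MobileNet's default input size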
BETTER_NMA/utilss/photos_utils.py
@@ -0,0 +1,104 @@
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.applications.resnet50 import preprocess_input as resnet50_preprocess
+from tensorflow.keras.applications.vgg16 import preprocess_input as vgg16_preprocess
+from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_v3_preprocess
+from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess
+from tensorflow.keras.applications.efficientnet import preprocess_input as efficientnet_preprocess
+from tensorflow.keras.applications.xception import preprocess_input as xception_preprocess
+from PIL import Image
+import io
+import base64
+
+def get_preprocess_function(model):
+    print("Determining preprocessing function based on model configuration...")
+    preprocess_map = {
+        "resnet50": resnet50_preprocess,
+        "vgg16": vgg16_preprocess,
+        "inception_v3": inception_v3_preprocess,
+        "mobilenet": mobilenet_preprocess,
+        "efficientnet": efficientnet_preprocess,
+        "xception": xception_preprocess,
+    }
+
+    model_config = model.get_config()
+    if "name" in model_config:
+        model_name = model_config["name"].lower()
+        print(f"Model name: {model_name}")
+        for key in preprocess_map.keys():
+            if key in model_name:
+                print(f"Detected model type: {key}")
+                return preprocess_map[key]
+
+    for layer in model.layers:
+        layer_name = layer.name.lower()
+        print(f"Checking layer: {layer_name}")
+        for model_name in preprocess_map.keys():
+            if model_name in layer_name:
+                print(f"Detected model type: {model_name}")
+                return preprocess_map[model_name]
+
+    print("No supported model type found in the configuration. Falling back to generic normalization.")
+    return lambda x: x / 255.0  # Generic normalization to [0, 1]
+
+
+_cached_preprocess_function = {}
+
+def get_cached_preprocess_function(model):
+    global _cached_preprocess_function
+    model_id = id(model)
+    if model_id not in _cached_preprocess_function:
+        _cached_preprocess_function[model_id] = get_preprocess_function(model)
+    return _cached_preprocess_function[model_id]
+
+def preprocess_loaded_image(model, image):
+    expected_shape = model.input_shape
+    input_height, input_width = expected_shape[1], expected_shape[2]
+
+    # Handle different input types
+    if isinstance(image, bytes):
+        # If image is bytes, convert to PIL Image
+        pil_image = Image.open(io.BytesIO(image)).convert("RGB")
+    elif isinstance(image, Image.Image):
+        # If image is already a PIL Image, just convert to RGB
+        pil_image = image.convert("RGB")
+    elif isinstance(image, np.ndarray):
+        # If image is a numpy array, convert to PIL Image
+        if image.dtype == np.uint8:
+            pil_image = Image.fromarray(image)
+        else:
+            # Normalize to 0-255 if not uint8
+            image_normalized = ((image - image.min()) / (image.max() - image.min()) * 255).astype(np.uint8)
+            pil_image = Image.fromarray(image_normalized)
+    else:
+        raise ValueError(f"Unsupported image type: {type(image)}")
+
+    pil_image = pil_image.resize((input_width, input_height))
+    preprocess_input = get_cached_preprocess_function(model)
+    image_array = preprocess_input(np.array(pil_image))
+    image_preprocessed = np.expand_dims(image_array, axis=0)
+    return image_preprocessed, pil_image
+
+def preprocess_image(model, image):
+    preprocess_input = get_cached_preprocess_function(model)
+    image_array = preprocess_input(np.array(image))
+    image_preprocessed = np.expand_dims(image_array, axis=0)
+    return image_preprocessed
+
+def encode_image_to_base64(image):
+    if isinstance(image, np.ndarray):
+        if image.dtype != np.uint8:
+            if image.max() <= 1.0:
+                image = (image * 255).astype(np.uint8)
+            else:
+                image = image.astype(np.uint8)
+
+        pil_image = Image.fromarray(image)
+        buffered = io.BytesIO()
+        pil_image.save(buffered, format="PNG")
+        img_str = base64.b64encode(buffered.getvalue()).decode()
+        return img_str
+    else:
+        raise ValueError("Input must be a numpy array")
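Note that this module's preprocess_loaded_image accepts bytes, PIL images, or arrays and returns a (batch, pil_image) pair, unlike the near-duplicate photos_uitls.py above, which accepts only bytes and returns just the batch. A short sketch of the base64 helper with a hypothetical float image:

    import base64
    import numpy as np

    img = np.random.rand(64, 64, 3).astype(np.float32)  # values in [0, 1]
    b64 = encode_image_to_base64(img)  # rescaled to uint8, saved as PNG, then base64-encoded

    png_bytes = base64.b64decode(b64)
    print(png_bytes[:8] == b"\x89PNG\r\n\x1a\n")  # True: the payload round-trips to PNG bytes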
BETTER_NMA/utilss/verbal_explanation.py
@@ -0,0 +1,15 @@
+import re
+
+def get_verbal_explanation(explanation):
+    result = []
+    for i, word in enumerate(explanation):
+        sanitized_word = re.sub(r'[0-9_]', '', word)
+        if i == 0:
+            result.append(f"**{sanitized_word}**")  # You can use bold markdown or just the word
+        else:
+            prev_sanitized = re.sub(r'[0-9_]', '', explanation[i - 1])
+            if sanitized_word == prev_sanitized:
+                continue
+            result.append(f" is a part of **{sanitized_word}**")
+    return ''.join(result)
+ return ''.join(result)