BETTER-NMA 1.0.1__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BETTER_NMA/explaination_score.py +52 -0
- BETTER_NMA/main.py +28 -3
- BETTER_NMA/nma_creator.py +3 -1
- BETTER_NMA/plot.py +5 -22
- BETTER_NMA/utilss/classes/adversarial_dataset.py +4 -9
- BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +4 -2
- BETTER_NMA/utilss/classes/score_calculator.py +0 -1
- BETTER_NMA/utilss/classes/whitebox_testing.py +13 -6
- BETTER_NMA/utilss/photos_uitls.py +3 -3
- BETTER_NMA/utilss/photos_utils.py +3 -3
- BETTER_NMA/utilss/wordnet_utils.py +123 -43
- BETTER_NMA/white_box_testing.py +1 -1
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/METADATA +1 -1
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/RECORD +16 -15
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/WHEEL +0 -0
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/top_level.txt +0 -0
BETTER_NMA/explaination_score.py
ADDED

@@ -0,0 +1,52 @@
+from .utilss.classes.score_calculator import ScoreCalculator
+from itertools import combinations
+
+def get_explaination_score(dendrogram, class_names, normalize=True):
+    """
+    Get the score of the entire dendrogram based on pairwise LCA ancestor counts.
+
+    Parameters:
+    - dendrogram: The hierarchical clustering dendrogram
+    - class_names: List of class names corresponding to model outputs
+
+    Returns:
+    - score: The explanation score of the dendrogram
+    """
+    score_calculator = ScoreCalculator(dendrogram.Z, class_names)
+
+    total_count = 0
+    num_classes = len(class_names)
+
+    # Get all combinations of 2 labels from class_names
+    for label1, label2 in combinations(class_names, 2):
+        try:
+            idx1 = class_names.index(label1)
+            idx2 = class_names.index(label2)
+            count, _ = score_calculator.count_ancestors_to_lca(idx1, idx2)
+            total_count += count
+        except ValueError as e:
+            print(f"Error processing pair ({label1}, {label2}): {e}")
+            continue
+
+    total_combinations = len(list(combinations(class_names, 2)))
+    print(f"Total combinations processed: {total_combinations}")
+    print(f"Total ancestor count sum: {total_count}")
+
+    if normalize:
+        # Maximum ancestors = height of dendrogram tree
+        max_ancestors_per_pair = num_classes - 1  # Maximum tree height
+        theoretical_max = total_combinations * max_ancestors_per_pair
+
+        if theoretical_max == 0:
+            normalized_score = 0
+        else:
+            # Invert the score so higher ancestor counts = lower explanation quality
+            # and normalize to 0-100%
+            normalized_score = max(0, (1 - (total_count / theoretical_max)) * 100)
+
+        print(f"Theoretical maximum: {theoretical_max}")
+        print(f"Normalized score: {normalized_score:.2f}%")
+
+        return normalized_score
+    else:
+        return total_count
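For reference, the normalization arithmetic above in a minimal, self-contained sketch. The class names and the summed ancestor count are invented stand-ins for what ScoreCalculator.count_ancestors_to_lca would produce on a real dendrogram:

from itertools import combinations

class_names = ["cat", "dog", "car", "truck"]                   # 4 hypothetical classes
total_combinations = len(list(combinations(class_names, 2)))   # C(4, 2) = 6 pairs
max_ancestors_per_pair = len(class_names) - 1                  # tree-height bound = 3
theoretical_max = total_combinations * max_ancestors_per_pair  # 6 * 3 = 18

total_count = 6  # suppose the pairwise LCA ancestor counts summed to 6
normalized_score = max(0, (1 - (total_count / theoretical_max)) * 100)
print(f"{normalized_score:.2f}%")  # 66.67%: fewer ancestors per LCA means a higher score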
BETTER_NMA/main.py
CHANGED

@@ -8,8 +8,11 @@ from .detect_attack import detect_adversarial_image
 from .query_image import query_image
 from .utilss.verbal_explanation import get_verbal_explanation
 from .white_box_testing import analyze_white_box_results, get_white_box_analysis
+from .explaination_score import get_explaination_score
 from .adversarial_score import get_adversarial_score
 from .find_lca import get_lca
+from .utilss.wordnet_utils import synset_to_readable
+import json
 
 class NMA:
     def __init__(self, x_train, y_train, labels, model, explanation_method, top_k=4, min_confidence=0.8, infinity=None, threshold=1e-6, save_connections=False, batch_size=32):
@@ -36,13 +39,12 @@ class NMA:
         self.model = model
         self.explanation_method = explanation_method
         self.top_k = top_k
-        self.labels = labels
+        self.labels = [synset_to_readable(label) for label in labels]
         self.min_confidence = min_confidence
         self.infinity = infinity
         self.threshold = threshold
         self.save_connections = save_connections
         self.batch_size = batch_size
-        self.labels = labels
         self.detector = None
 
         self.dendrogram_object, self.edges_df = preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size)
@@ -81,6 +83,25 @@ class NMA:
         """
         plot_sub_dendrogram(self.dendrogram_object.Z, self.labels, sub_labels, title=title, figsize=figsize)
 
+
+    def get_tree_as_dict(self, sub_labels=None):
+        """
+        Returns the dendrogram hierarchy as a mutable Python dictionary.
+
+        Inputs:
+        - sub_labels (optional): List of labels to include in the subset.
+
+        Outputs: Dictionary representation of the dendrogram tree.
+        """
+        if self.dendrogram_object is None:
+            raise ValueError("Dendrogram not available.")
+
+        if sub_labels is None:
+            sub_labels = self.labels
+
+        json_str = self.dendrogram_object.get_sub_dendrogram_formatted(sub_labels)
+        return json.loads(json_str)
+
     ## white box testing functions: ##
 
     def white_box_testing(self, source_labels, target_labels, analyze_results=False, x_train=None, encode_images=True):
@@ -101,7 +122,7 @@ class NMA:
         if self.edges_df is None:
             raise ValueError("White box testing requires edges_df. Initialize NMA with save_connections=True")
 
-        whitebox = WhiteBoxTesting(self.model.name if hasattr(self.model, 'name') else "model")
+        whitebox = WhiteBoxTesting(self.model.name if hasattr(self.model, 'name') else "model", verbose=False)
         problematic_imgs_dict = whitebox.find_problematic_images(
             source_labels, target_labels, self.edges_df, self.explanation_method)
 
@@ -212,6 +233,10 @@ class NMA:
         score = get_adversarial_score(image, self.model, self.dendrogram_object.Z, self.labels, top_k=top_k)
         return score
 
+    def explanation_score(self, normalize=True):
+        return get_explaination_score(self.dendrogram_object, self.labels, normalize=normalize)
+
+
     ## query and explanation functions: ##
 
     def query_image(self, image, top_k=5):
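How the two methods this release adds to NMA would be called — a hedged sketch where the instance name `nma` and its constructor arguments are illustrative, and the name/children node shape is the one _rename_clusters in wordnet_utils.py walks:

# Assumes an already-built instance, e.g.:
#   nma = NMA(x_train, y_train, labels, model, explanation_method)
tree = nma.get_tree_as_dict()  # dict parsed via json.loads; safe to mutate
print(tree["name"], len(tree.get("children", [])))

score = nma.explanation_score(normalize=True)  # thin wrapper over get_explaination_score
print(f"explanation score: {score:.2f}%")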
BETTER_NMA/nma_creator.py
CHANGED

@@ -8,11 +8,13 @@ from .utilss.classes.preprocessing.graph_builder import GraphBuilder
 from .utilss.classes.preprocessing.hierarchical_clustering_builder import HierarchicalClusteringBuilder
 from .utilss.classes.preprocessing.z_builder import ZBuilder
 from .utilss.classes.dendrogram import Dendrogram
+from .utilss.wordnet_utils import synset_to_readable
 
 def preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size=32):
     try:
         X = x_train
-        y = y_train
+        y = [synset_to_readable(l) for l in y_train]
+        labels = [synset_to_readable(l) for l in labels]
 
         graph = Graph(directed=False)
         graph.add_vertices(labels)
BETTER_NMA/plot.py
CHANGED

@@ -101,31 +101,14 @@ def plot(nma_instance, sub_labels, title, figsize, **kwargs):
         raise ValueError("No linkage matrix (z) found in NMA instance")
 
     if sub_labels is None:
-        sub_labels
+        print("No sub_labels provided.")
+        return
 
-
-
-
-    print(filtered_dendrogram_json)
+    _ = nma_instance.dendrogram_object.get_sub_dendrogram_formatted(sub_labels)
+    # filtered_dendrogram_json = nma_instance.dendrogram_object.get_sub_dendrogram_formatted(
+    #     sub_labels)
 
     if hasattr(nma_instance, 'labels'):
         plot_sub_dendrogram(nma_instance.dendrogram_object.Z,
                             nma_instance.labels, sub_labels, title, figsize)
 
-    print(nma_instance.dendrogram_object.Z)
-
-    # plt.figure(figsize=(20, 15))
-    # sch.dendrogram(
-    #     nma_instance.z,
-    #     leaf_rotation=0,
-    #     leaf_font_size=10,
-    #     orientation='right',
-    #     color_threshold=85,
-    #     **kwargs
-    # )
-    # plt.title('NMA Hierarchical Clustering Dendrogram')
-    # plt.xlabel('Elements')
-    # plt.ylabel('Distance')
-    # plt.tight_layout()
-    # plt.show()
-
BETTER_NMA/utilss/classes/adversarial_dataset.py
CHANGED

@@ -13,32 +13,27 @@ class AdversarialDataset:
         scores = []
         labels = []
 
-        print("getting preprocess function...")
-
         try:
             for image in self.clear_images[:50]:
                 # Add batch dimension for model prediction
                 image_batch = np.expand_dims(image, axis=0)
-                score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch))
+                score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch, verbose=0))
                 scores.append(score)
                 labels.append(0)
         except Exception as e:
-            print(f"Error processing clean
+            print(f"Error processing clean images: {e}")
 
         # Generate features for PGD attacks
-        print("Generating attack features...")
         try:
             for adv_image in self.adversarial_images[:50]:
                 # Add batch dimension for model prediction
                 adv_image_batch = np.expand_dims(adv_image, axis=0)
-                score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch))
+                score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch, verbose=0))
                 scores.append(score)
                 labels.append(1)
         except Exception as e:
-            print(f"Error processing
+            print(f"Error processing attacked images: {e}")
 
-        print("labels:", labels)
-        print("scores:", scores)
 
         # Convert to numpy arrays
         X = np.array(scores)
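The predict edits above (and the one in batch_predictor.py below) pass `verbose=0` to `tf.keras.Model.predict`, which suppresses the progress bar Keras prints on every call — with one image per call over a hundred images, that is a lot of log noise. A minimal sketch with a throwaway model:

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(2, input_shape=(3,))])
x = np.zeros((1, 3), dtype=np.float32)
preds = model.predict(x, verbose=0)  # no "1/1 [=====]" progress line
print(preds.shape)                   # (1, 2)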
BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py
CHANGED

@@ -1,5 +1,6 @@
 import numpy as np
 import tensorflow as tf
+from ...wordnet_utils import synset_to_readable
 
 
 class BatchPredictor:
@@ -11,14 +12,15 @@ class BatchPredictor:
         self.buffer_results = []  # To store batch results
 
     def get_top_predictions(self, X, labels, top_k, graph_threshold):
-        batch_preds = self.model.predict(np.array(X))
+        batch_preds = self.model.predict(np.array(X), verbose=0)
         batch_results = []
         for pred in batch_preds:
             top_indices = pred.argsort()[-top_k:][::-1]
             valid_indices = [i for i in top_indices if i < len(labels)]
 
             top_predictions = [
-                (i, labels[i], pred[i])
+                # (i, labels[i], pred[i])
+                (i, synset_to_readable(labels[i]), pred[i])
                 for i in valid_indices
                 if pred[i] >= graph_threshold
             ]
BETTER_NMA/utilss/classes/score_calculator.py
CHANGED

@@ -134,7 +134,6 @@ class ScoreCalculator:
 
         # Calculate semantic distance
         rank_count, _ = self.count_ancestors_to_lca(idx1, idx2)
-        print(f"Rank Count for {label1} and {label2}: {rank_count}")
 
         # Calculate product of probabilities and distance
         prob_product = prob1 * prob2
BETTER_NMA/utilss/classes/whitebox_testing.py
CHANGED

@@ -1,10 +1,11 @@
 import pandas as pd
 
 class WhiteBoxTesting:
-    def __init__(self, model_name):
+    def __init__(self, model_name, verbose=False):
         self.model_name = model_name
         self.problematic_img_ids = None
         self.problematic_img_preds = None
+        self.verbose = verbose
 
 
     def find_problematic_images(self, source_labels, target_labels, edges_df, explanation_method=None):
@@ -17,11 +18,15 @@ class WhiteBoxTesting:
             (edges_df['source'].isin(target_labels)) &
             (edges_df['target'].isin(source_labels))
         ]
-
+
+        if self.verbose:
+            print(filtered_edges_df_switched.head())
 
         combined_filtered_edges_df = pd.concat([filtered_edges_df, filtered_edges_df_switched])
-
-
+
+        if self.verbose:
+            print("Combined filtered edges dataset:")
+            print(combined_filtered_edges_df)
 
         unique_ids_list = combined_filtered_edges_df['image_id'].unique().tolist()
 
@@ -29,7 +34,9 @@ class WhiteBoxTesting:
             image_id: list(zip(group['source'], group['target'], group['target_probability']))
             for image_id, group in edges_df[edges_df['image_id'].isin(unique_ids_list)].groupby('image_id')
         }
-
-
+
+        if self.verbose:
+            print("Matched dictionary:")
+            print(matched_dict)
 
         return matched_dict
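Exercising the new `verbose` flag — a hedged sketch with a toy DataFrame. The column names (image_id, source, target, target_probability) are the ones this diff filters and groups on; the rows are made up:

import pandas as pd
from BETTER_NMA.utilss.classes.whitebox_testing import WhiteBoxTesting

edges_df = pd.DataFrame({
    "image_id": [1, 1, 2],
    "source": ["cat", "dog", "cat"],
    "target": ["dog", "cat", "car"],
    "target_probability": [0.7, 0.6, 0.2],
})

noisy = WhiteBoxTesting("demo_model", verbose=True)  # prints the intermediates
quiet = WhiteBoxTesting("demo_model")                # verbose now defaults to False
matched = noisy.find_problematic_images(["cat"], ["dog"], edges_df)
# matched maps image_id -> [(source, target, target_probability), ...]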
BETTER_NMA/utilss/photos_uitls.py
CHANGED

@@ -24,10 +24,10 @@ def get_preprocess_function(model):
     model_config = model.get_config()
     if "name" in model_config:
         model_name = model_config["name"].lower()
-        print(f"Model name: {model_name}")
+        # print(f"Model name: {model_name}")
         for key in preprocess_map.keys():
             if key in model_name:
-                print(f"Detected model type: {key}")
+                # print(f"Detected model type: {key}")
                 return preprocess_map[key]
 
     for layer in model.layers:
@@ -35,7 +35,7 @@ def get_preprocess_function(model):
         print(f"Checking layer: {layer_name}")
         for model_name in preprocess_map.keys():
             if model_name in layer_name:
-                print(f"Detected model type: {model_name}")
+                # print(f"Detected model type: {model_name}")
                 return preprocess_map[model_name]
 
     print("No supported model type found in the configuration. Falling back to generic normalization.")
BETTER_NMA/utilss/photos_utils.py
CHANGED

@@ -25,15 +25,15 @@ def get_preprocess_function(model):
     model_config = model.get_config()
     if "name" in model_config:
         model_name = model_config["name"].lower()
-        print(f"Model name: {model_name}")
+        # print(f"Model name: {model_name}")
         for key in preprocess_map.keys():
             if key in model_name:
-                print(f"Detected model type: {key}")
+                # print(f"Detected model type: {key}")
                 return preprocess_map[key]
 
     for layer in model.layers:
         layer_name = layer.name.lower()
-        print(f"Checking layer: {layer_name}")
+        # print(f"Checking layer: {layer_name}")
         for model_name in preprocess_map.keys():
             if model_name in layer_name:
                 print(f"Detected model type: {model_name}")
BETTER_NMA/utilss/wordnet_utils.py
CHANGED

@@ -12,9 +12,32 @@ def folder_name_to_number(folder_name):
     folder_number = 'n{:08d}'.format(offset)
     return folder_number
 
+def synset_to_readable(label):
+    # Check if label is in synset format
+    if isinstance(label, str) and label.startswith('n') and label[1:].isdigit():
+        special_cases = {
+            "n02012849": "crane bird",     # Bird
+            "n03126707": "crane machine",  # Vehicle
+            "n03710637": "maillot",        # Swimsuit
+            "n03710721": "tank suit"       # Swimsuit
+        }
+
+        if label in special_cases:
+            return special_cases[label]
+
+        try:
+            offset = int(label[1:])
+            synset = wn.synset_from_pos_and_offset('n', offset)
+            return synset.lemma_names()[0].replace('_', ' ')
+        except Exception:
+            return label  # fallback if not found
+    else:
+        return label  # already readable
+
 def common_group(groups):
     common_hypernyms = []
     hierarchy = {}
+
     for group in groups:
         hierarchy[group] = []
         synsets = wn.synsets(group)
@@ -29,48 +52,118 @@ def common_group(groups):
     for hypernym in hierarchy[groups.pop()]:
         if all(hypernym in hypernyms for hypernyms in hierarchy.values()):
             common_hypernyms.append(hypernym)
+
     return common_hypernyms[::-1]
 
+def process_hierarchy(hierarchy_data,):
+    """Process the entire hierarchy, renaming clusters while preserving structure."""
+    return _rename_clusters(hierarchy_data)
 
 def get_all_leaf_names(node):
+    """Extract all leaf node names from a cluster hierarchy."""
     if "children" not in node:
+        # Only return actual object names, not cluster names
        if "cluster" not in node["name"]:
            return [node["name"]]
        return []
+
     names = []
     for child in node["children"]:
         names.extend(get_all_leaf_names(child))
     return names
+
+def _rename_clusters(tree):
+    """
+    Traverse the tree in BFS manner and rename clusters based on child names,
+    which can be leaves or already-renamed clusters.
+    """
+    used_names = set()
+    all_leaf_names = {leaf.lower() for leaf in get_all_leaf_names(tree)}
 
+    queue = deque()
+    queue.append(tree)
 
-
-
+    # BFS traversal, we store nodes with children in postprocess queue
+    postprocess_nodes = []
 
+    while queue:
+        node = queue.popleft()
+        if "children" in node:
+            queue.extend(node["children"])
+            postprocess_nodes.append(node)  # non-leaf clusters to process after children
 
-
+    # Process clusters in reverse BFS (bottom-up)
+    for node in reversed(postprocess_nodes):
+        if "cluster" not in node["name"]:
+            continue  # already renamed
+
+        # Collect child names (renamed or original leaves)
+        child_names = [child["name"] for child in node["children"] if "name" in child]
+
+        # Get hypernym candidate from child names
+        candidate = find_common_hypernyms(child_names)
+        if candidate:
+            # Ensure it's unique
+            base = candidate
+            unique = base
+            idx = 1
+            while unique.lower() in all_leaf_names or unique.lower() in {n.lower() for n in used_names}:
+                idx += 1
+                unique = f"{base}_{idx}"
+            node["name"] = unique
+            used_names.add(unique)
+
+    return tree
+
+def _get_top_synsets(
+    phrase: str,
+    pos=wn.NOUN,
+    max_senses: int = 15
+) -> list[wn.synset]:
+    """
+    Return up to `max_senses` synsets for `phrase`.
+    - Replaces spaces/underscores so WordNet can match "pickup truck" or "aquarium_fish".
+    - WordNet already orders synsets by frequency, so we take only the first few.
+    """
     lemma = phrase.strip().lower().replace(" ", "_")
     syns = wn.synsets(lemma, pos=pos)
     return syns[:max_senses] if syns else []
 
 
+# ---------------------------------------------------
+# Core: compute the single best hypernym for a set of words
+# ---------------------------------------------------
 def _find_best_common_hypernym(
     leaves: list[str],
     max_senses_per_word: int = 5,
-    banned_lemmas: set[str] = None
-) ->
-
+    banned_lemmas: set[str] = None,
+) -> str | None:
+    """
+    1. For each leaf in `leaves`, fetch up to `max_senses_per_word` synsets.
+    2. For EVERY pair of leaves (w1, w2), for EVERY combination of synset ∈ synsets(w1) × synsets(w2),
+       call syn1.lowest_common_hypernyms(syn2) → yields a list of shared hypernyms.
+       Tally them in `lch_counter`.
+    3. Sort the candidates by (frequency, min_depth) so we pick the most-specific, most-common ancestor.
+    4. Filter out overly generic lemmas (like "entity", "object") unless NOTHING else remains.
+    5. Return the best lemma_name (underscore → space, capitalized).
+    """
     if banned_lemmas is None:
         banned_lemmas = {"entity", "object", "physical_entity", "thing", "Object", "Whole", "Whole", "Physical_entity", "Thing", "Entity", "Artifact"}
-
+
+
+    # 1. Map each leaf → up to `max_senses_per_word` synsets
     word_to_synsets: dict[str, list[wn.synset]] = {}
     for w in leaves:
         syns = _get_top_synsets(w, wn.NOUN, max_senses_per_word)
         if syns:
             word_to_synsets[w] = syns
 
+    # If fewer than 2 words have ANY synsets, we cannot get a meaningful common hypernym
     if len(word_to_synsets) < 2:
         return None
 
+    # 2. For each pair of distinct leaves w1, w2, do ALL combinations of synset₁ × synset₂
+    #    and tally lowest_common_hypernyms
     lch_counter: Counter[wn.synset] = Counter()
     words_list = list(word_to_synsets.keys())
 
@@ -82,7 +175,6 @@ def _find_best_common_hypernym(
         try:
             common = s1.lowest_common_hypernyms(s2)
         except Exception as e:
-            print(f"Error computing LCH({s1.name()}, {s2.name()}): {e}")
             continue
         for hyp in common:
             lch_counter[hyp] += 1
@@ -90,12 +182,14 @@ def _find_best_common_hypernym(
     if not lch_counter:
         return None
 
+    # 3. Sort candidates by (frequency, min_depth) descending
     candidates = sorted(
         lch_counter.items(),
         key=lambda item: (item[1], item[0].min_depth()),
         reverse=True
     )
 
+    # 4. Filter out generic lemma_names unless NOTHING else remains
     filtered: list[tuple[wn.synset, int]] = []
     for syn, freq in candidates:
         lemma = syn.name().split(".")[0].lower()
@@ -103,75 +197,61 @@ def _find_best_common_hypernym(
             continue
         filtered.append((syn, freq))
 
+    # If every candidate was filtered out, allow the first generic anyway
     if not filtered:
         filtered = candidates
 
     best_synset, best_freq = filtered[0]
     best_label = (best_synset.name().split(".")[0].replace(" ", "_")).lower()
+
     return best_label
 
+
+# ---------------------------------------------------
+# Public version: branching on single vs. multiple leaves
+# ---------------------------------------------------
 def find_common_hypernyms(
     words: list[str],
     abstraction_level: int = 0,
-) ->
-
+) -> str | None:
+    """
+    Improved drop-in replacement for your old `find_common_hypernyms`.
+    1. Normalize each word (underscores ↔ spaces, lowercase) and filter out anything containing "Cluster".
+    2. If there's exactly one valid leaf, pick its first hypernym (one level up) unless it's "entity".
+    3. If there are 2+ leaves, call _find_best_common_hypernym on them.
+    """
+
     clean_leaves = [
+        # w.strip().lower().replace(" ", "_")
        re.sub(r'_\d+$', '', w.strip().lower().replace(" ", "_"))
        for w in words
        if w and "cluster" not in w.lower()
     ]
 
+    # If nothing remains, bail out
     if not clean_leaves:
         return None
 
+    # Single-word case: pick its immediate hypernym (second-to-bottom in the hypernym path)
     if len(clean_leaves) == 1:
         word = clean_leaves[0]
         synsets = _get_top_synsets(word, wn.NOUN, max_senses=10)
         if not synsets:
             return None
 
+        # Choose the first sense's longest hypernym path, then take one level up from leaf sense.
        paths = synsets[0].hypernym_paths()  # list of lists
        if not paths:
            return None
 
        longest_path = max(paths, key=lambda p: len(p))
+        # If path has at least 2 nodes, candidate = one level above the leaf sense
        if len(longest_path) >= 2:
            candidate = longest_path[-2]
            name = (candidate.name().split(".")[0].replace(" ", "_")).lower()
            if name.lower() not in {word, "entity"}:
                return name
        return None
-    return _find_best_common_hypernym(clean_leaves, max_senses_per_word=5)
-
-
-def _rename_clusters(tree):
-    used_names = set()
-    all_leaf_names = {leaf.lower() for leaf in get_all_leaf_names(tree)}
-    queue = deque()
-    queue.append(tree)
-    postprocess_nodes = []
-
-    while queue:
-        node = queue.popleft()
-        if "children" in node:
-            queue.extend(node["children"])
-            postprocess_nodes.append(node)
-
-    for node in reversed(postprocess_nodes):
-        if "cluster" not in node["name"]:
-            continue
 
-
-
-
-        if candidate:
-            base = candidate
-            unique = base
-            idx = 1
-            while unique.lower() in all_leaf_names or unique.lower() in {n.lower() for n in used_names}:
-                idx += 1
-                unique = f"{base}_{idx}"
-            node["name"] = unique
-            used_names.add(unique)
-
-    return tree
+    # 2+ leaves: use pairwise LCA approach
+    return _find_best_common_hypernym(clean_leaves, max_senses_per_word=5)
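A hedged sketch of the new label-normalization path, runnable once the NLTK WordNet corpus is available. The offset n02084071 (dog.n.01) is illustrative, and the hypernym printed by the last call depends on WordNet sense ordering:

import nltk
nltk.download("wordnet", quiet=True)

from BETTER_NMA.utilss.wordnet_utils import synset_to_readable, find_common_hypernyms

print(synset_to_readable("n02084071"))  # "dog": first lemma, underscores -> spaces
print(synset_to_readable("n02012849"))  # "crane bird": special-cased above
print(synset_to_readable("goldfish"))   # "goldfish": already readable, returned as-is
print(find_common_hypernyms(["dog", "cat"]))  # e.g. "carnivore" via the pairwise LCH tally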
BETTER_NMA/white_box_testing.py
CHANGED

@@ -57,7 +57,7 @@ def get_white_box_analysis(edges_df_path, model_filename, dataset_str, source_la
     edges_data.load_dataframe()
     df = edges_data.get_dataframe()
 
-    whitebox_testing = WhiteBoxTesting(model_filename)
+    whitebox_testing = WhiteBoxTesting(model_filename, verbose=False)
     problematic_imgs_dict = whitebox_testing.find_problematic_images(
         source_labels, target_labels, df, dataset_str)
 
{BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/RECORD
RENAMED

@@ -2,28 +2,29 @@ BETTER_NMA/__init__.py,sha256=ePaQnto0n4hccz2490Z7bxwcbtONVAa6nWqg7SL4W1Y,428
 BETTER_NMA/adversarial_score.py,sha256=qgScTqS-aJ2q4kFom505hBtonVzKK67fGS09J1_-G3o,875
 BETTER_NMA/change_cluster_name.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 BETTER_NMA/detect_attack.py,sha256=s7YwTVMJFABSMt2aISR-zaIUxFaSWm9oODc9yF12KPY,4327
+BETTER_NMA/explaination_score.py,sha256=GWtncYjIj8UF94g1UxtNRnyK7Cvh3PqVcShle_vLhUM,2003
 BETTER_NMA/find_lca.py,sha256=UlyftOJmbSPuXzxvcheRb_IrdCqBsaSQHLchIRZIR-0,812
-BETTER_NMA/main.py,sha256=
-BETTER_NMA/nma_creator.py,sha256=
-BETTER_NMA/plot.py,sha256=
+BETTER_NMA/main.py,sha256=c-JZX3F2Ozh-ctVtlnB19hN4gptxt-4wQsecBzD-fSU,12258
+BETTER_NMA/nma_creator.py,sha256=mJRRFX2d5pUXORT0jKpMV5QW8FFqn1bJgUmyqPVnMQA,5282
+BETTER_NMA/plot.py,sha256=ySNXrlQgwV6gm2Dw6yWT-dfrY7A7mRTJNUgsd6sOxic,3954
 BETTER_NMA/query_image.py,sha256=13AQ9-8QdzaIwH5-ELX3z3iJBP8nTDe-SMtwQve-1ek,906
 BETTER_NMA/train_adversarial_detector.py,sha256=nMaQ-Pm2vP84qNR1GoKQiVPpmMC3rdorzDMf5gDwKTE,977
-BETTER_NMA/white_box_testing.py,sha256=
+BETTER_NMA/white_box_testing.py,sha256=VZ5pImXUOpM6jWMOoIkTWymwPCsev75zQ2SudSJ0frw,3539
 BETTER_NMA/utilss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 BETTER_NMA/utilss/models_utils.py,sha256=gBXY2LFH4iR-2GZmHeUnnB5n9t3VdjIc9sugHDrD3AM,671
-BETTER_NMA/utilss/photos_uitls.py,sha256=
-BETTER_NMA/utilss/photos_utils.py,sha256=
+BETTER_NMA/utilss/photos_uitls.py,sha256=wxmIIKFgAKYkcYaK95UMjtY-LZS6NDVveKrHBQV8Q70,3166
+BETTER_NMA/utilss/photos_utils.py,sha256=4EjDHbMjrJ8P9y-X4H05P4wez4uKNit60UGnu3sKsys,4412
 BETTER_NMA/utilss/verbal_explanation.py,sha256=_hrYZUjBUYOfuGr7t5r-DACooR5d60dRtGfUj7FbeZw,549
-BETTER_NMA/utilss/wordnet_utils.py,sha256=
+BETTER_NMA/utilss/wordnet_utils.py,sha256=77qcmEQH3Krd1T8dQY-IXVpaEgfwlw406XRk4zYsghw,9482
 BETTER_NMA/utilss/classes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-BETTER_NMA/utilss/classes/adversarial_dataset.py,sha256=
+BETTER_NMA/utilss/classes/adversarial_dataset.py,sha256=LKmyseQetyBdoKrl7Q4MaTc7EzMZcMgqNdtCsdAvKHA,2299
 BETTER_NMA/utilss/classes/adversarial_detector.py,sha256=BE_SxNEwcvuHERBiefefOmk1k6NJSo6juehkAjkEHuQ,2331
 BETTER_NMA/utilss/classes/dendrogram.py,sha256=vtKBFfwzcz8k01Goc83pZlWC2pO86endTJURlkUWVQI,5141
 BETTER_NMA/utilss/classes/edges_dataframe.py,sha256=q-RQ6beOeZeIgdEzwi8T5Ag2NBFySv7-ITD5m989nl4,1896
-BETTER_NMA/utilss/classes/score_calculator.py,sha256=
-BETTER_NMA/utilss/classes/whitebox_testing.py,sha256=
+BETTER_NMA/utilss/classes/score_calculator.py,sha256=zgZaFgFJeok2RzXFm9OE5pOuXY3euIXEiLGI44q30JM,5927
+BETTER_NMA/utilss/classes/whitebox_testing.py,sha256=_WsJWUSSb7Ndvzw2ftM7QKXiz_8NXOMS3njBKPAM4Rw,1542
 BETTER_NMA/utilss/classes/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py,sha256=
+BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py,sha256=3-hs9Mr30B0n9PQZE20-3RoLlsS5dw56yQX2MU4UnT4,1070
 BETTER_NMA/utilss/classes/preprocessing/graph_builder.py,sha256=ILumiBY9BUIOxrIvq8C-8n945pK-t94Et6gZwJB-364,1672
 BETTER_NMA/utilss/classes/preprocessing/heap_processor.py,sha256=KblmkVWVfMYtpZa4Wy1Ry0lVfdSr6h8LySt4S-lvIGo,1064
 BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py,sha256=YAIElJS_fSffIb3D2N1OZu9U6z7RYrHQTfB6bH4-VPI,4027
@@ -32,7 +33,7 @@ BETTER_NMA/utilss/classes/preprocessing/z_builder.py,sha256=T8ETfL7mMOgEj7oYNsw6
 BETTER_NMA/utilss/enums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 BETTER_NMA/utilss/enums/explanation_method.py,sha256=Ang-rjvxO4AJ1IH4mwS8sNpSwt9jn3PlqFbPPT-R9I8,150
 BETTER_NMA/utilss/enums/heap_types.py,sha256=0z1d2qu1ZCbpWRXKD1dTopn3M4G1CxRQW9HWxVxyPIA,88
-BETTER_NMA-1.0.
-BETTER_NMA-1.0.
-BETTER_NMA-1.0.
-BETTER_NMA-1.0.
+BETTER_NMA-1.0.5.dist-info/METADATA,sha256=oqz6IAT6kY-q94bKDTGFxSxAXI47J1P_GWcZQXKeaxM,5100
+BETTER_NMA-1.0.5.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
+BETTER_NMA-1.0.5.dist-info/top_level.txt,sha256=SVRNqWPvCnynWVyXNAYnf9CSQIvMAvE6iyyiGHodQgY,11
+BETTER_NMA-1.0.5.dist-info/RECORD,,
{BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/WHEEL
File without changes

{BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.5.dist-info}/top_level.txt
File without changes