BETTER-NMA 1.0.2__tar.gz → 1.0.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/main.py +23 -3
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/nma_creator.py +3 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/plot.py +5 -22
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +3 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/whitebox_testing.py +13 -6
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/wordnet_utils.py +123 -43
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/white_box_testing.py +1 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/PKG-INFO +1 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/SOURCES.txt +2 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/PKG-INFO +1 -1
- {better_nma-1.0.2 → better_nma-1.0.5}/setup.py +1 -1
- better_nma-1.0.5/tests/test_main.py +280 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/__init__.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/adversarial_score.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/change_cluster_name.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/detect_attack.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/explaination_score.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/find_lca.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/query_image.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/train_adversarial_detector.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/__init__.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/__init__.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/adversarial_dataset.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/adversarial_detector.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/dendrogram.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/edges_dataframe.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/tree_node.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/z_builder.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/score_calculator.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/enums/__init__.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/enums/explanation_method.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/enums/heap_types.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/models_utils.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/photos_uitls.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/photos_utils.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/verbal_explanation.py +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/dependency_links.txt +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/requires.txt +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/top_level.txt +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/README.md +0 -0
- {better_nma-1.0.2 → better_nma-1.0.5}/setup.cfg +0 -0
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/main.py
RENAMED
@@ -11,6 +11,8 @@ from .white_box_testing import analyze_white_box_results, get_white_box_analysis
 from .explaination_score import get_explaination_score
 from .adversarial_score import get_adversarial_score
 from .find_lca import get_lca
+from .utilss.wordnet_utils import synset_to_readable
+import json

 class NMA:
     def __init__(self, x_train, y_train, labels, model, explanation_method, top_k=4, min_confidence=0.8, infinity=None, threshold=1e-6, save_connections=False, batch_size=32):
@@ -37,13 +39,12 @@ class NMA:
         self.model = model
         self.explanation_method = explanation_method
         self.top_k = top_k
-        self.labels = labels
+        self.labels = [synset_to_readable(label) for label in labels]
         self.min_confidence = min_confidence
         self.infinity = infinity
         self.threshold = threshold
         self.save_connections = save_connections
         self.batch_size = batch_size
-        self.labels = labels
         self.detector = None

         self.dendrogram_object, self.edges_df = preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size)
@@ -82,6 +83,25 @@ class NMA:
         """
         plot_sub_dendrogram(self.dendrogram_object.Z, self.labels, sub_labels, title=title, figsize=figsize)

+
+    def get_tree_as_dict(self, sub_labels=None):
+        """
+        Returns the dendrogram hierarchy as a mutable Python dictionary.
+
+        Inputs:
+        - sub_labels (optional): List of labels to include in the subset.
+
+        Outputs: Dictionary representation of the dendrogram tree.
+        """
+        if self.dendrogram_object is None:
+            raise ValueError("Dendrogram not available.")
+
+        if sub_labels is None:
+            sub_labels = self.labels
+
+        json_str = self.dendrogram_object.get_sub_dendrogram_formatted(sub_labels)
+        return json.loads(json_str)
+
     ## white box testing functions: ##

     def white_box_testing(self, source_labels, target_labels, analyze_results=False, x_train=None, encode_images=True):
@@ -102,7 +122,7 @@ class NMA:
         if self.edges_df is None:
             raise ValueError("White box testing requires edges_df. Initialize NMA with save_connections=True")

-        whitebox = WhiteBoxTesting(self.model.name if hasattr(self.model, 'name') else "model")
+        whitebox = WhiteBoxTesting(self.model.name if hasattr(self.model, 'name') else "model", verbose=False)
         problematic_imgs_dict = whitebox.find_problematic_images(
             source_labels, target_labels, self.edges_df, self.explanation_method)

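For orientation, a minimal sketch of how the new `get_tree_as_dict` method is meant to be called. The data and model variables are placeholders the caller must prepare; only the `NMA` constructor and method names come from the code above:

```python
from BETTER_NMA import NMA

# x_train, y_train, labels, model are assumed to be prepared by the caller.
nma = NMA(x_train, y_train, labels, model,
          explanation_method="similarity", save_connections=True)

# Full hierarchy as a mutable dict, parsed from the dendrogram's JSON form.
tree = nma.get_tree_as_dict()
print(tree.get("name"))  # root cluster name

# Restrict the tree to a subset of (readable) labels.
sub_tree = nma.get_tree_as_dict(sub_labels=["bee", "beetle", "butterfly"])
```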
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/nma_creator.py
RENAMED
@@ -8,11 +8,13 @@ from .utilss.classes.preprocessing.graph_builder import GraphBuilder
 from .utilss.classes.preprocessing.hierarchical_clustering_builder import HierarchicalClusteringBuilder
 from .utilss.classes.preprocessing.z_builder import ZBuilder
 from .utilss.classes.dendrogram import Dendrogram
+from .utilss.wordnet_utils import synset_to_readable

 def preprocessing(x_train, y_train, labels, model, explanation_method, top_k, min_confidence, infinity, threshold, save_connections, batch_size=32):
     try:
         X = x_train
-        y = y_train
+        y = [synset_to_readable(l) for l in y_train]
+        labels = [synset_to_readable(l) for l in labels]

         graph = Graph(directed=False)
         graph.add_vertices(labels)
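Both `y_train` and the label list are now normalized through `synset_to_readable`, so callers can mix ImageNet-style synset IDs with plain class names. A small illustration of the mapping (behavior follows the helper added in `wordnet_utils.py` below; the non-special-case path assumes NLTK's WordNet corpus is installed):

```python
from BETTER_NMA.utilss.wordnet_utils import synset_to_readable

print(synset_to_readable("n02012849"))  # "crane bird" (hard-coded special case)
print(synset_to_readable("n01440764"))  # first WordNet lemma for the offset, e.g. "tench"
print(synset_to_readable("goldfish"))   # non-synset strings pass through unchanged
```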
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/plot.py
RENAMED
@@ -101,31 +101,14 @@ def plot(nma_instance, sub_labels, title, figsize, **kwargs):
         raise ValueError("No linkage matrix (z) found in NMA instance")

     if sub_labels is None:
-        sub_labels
+        print("No sub_labels provided.")
+        return

-
-
-
-    print(filtered_dendrogram_json)
+    _ = nma_instance.dendrogram_object.get_sub_dendrogram_formatted(sub_labels)
+    # filtered_dendrogram_json = nma_instance.dendrogram_object.get_sub_dendrogram_formatted(
+    #     sub_labels)

     if hasattr(nma_instance, 'labels'):
         plot_sub_dendrogram(nma_instance.dendrogram_object.Z,
                             nma_instance.labels, sub_labels, title, figsize)

-        print(nma_instance.dendrogram_object.Z)
-
-    # plt.figure(figsize=(20, 15))
-    # sch.dendrogram(
-    #     nma_instance.z,
-    #     leaf_rotation=0,
-    #     leaf_font_size=10,
-    #     orientation='right',
-    #     color_threshold=85,
-    #     **kwargs
-    # )
-    # plt.title('NMA Hierarchical Clustering Dendrogram')
-    # plt.xlabel('Elements')
-    # plt.ylabel('Distance')
-    # plt.tight_layout()
-    # plt.show()
-
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py
RENAMED
@@ -1,5 +1,6 @@
 import numpy as np
 import tensorflow as tf
+from ...wordnet_utils import synset_to_readable


 class BatchPredictor:
@@ -18,7 +19,8 @@ class BatchPredictor:
         valid_indices = [i for i in top_indices if i < len(labels)]

         top_predictions = [
-            (i, labels[i], pred[i])
+            # (i, labels[i], pred[i])
+            (i, synset_to_readable(labels[i]), pred[i])
             for i in valid_indices
             if pred[i] >= graph_threshold
         ]
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/classes/whitebox_testing.py
RENAMED
@@ -1,10 +1,11 @@
 import pandas as pd

 class WhiteBoxTesting:
-    def __init__(self, model_name):
+    def __init__(self, model_name, verbose=False):
         self.model_name = model_name
         self.problematic_img_ids = None
         self.problematic_img_preds = None
+        self.verbose = verbose


     def find_problematic_images(self, source_labels, target_labels, edges_df, explanation_method=None):
@@ -17,11 +18,15 @@ class WhiteBoxTesting:
             (edges_df['source'].isin(target_labels)) &
             (edges_df['target'].isin(source_labels))
         ]
-
+
+        if self.verbose:
+            print(filtered_edges_df_switched.head())

         combined_filtered_edges_df = pd.concat([filtered_edges_df, filtered_edges_df_switched])
-
-
+
+        if self.verbose:
+            print("Combined filtered edges dataset:")
+            print(combined_filtered_edges_df)

         unique_ids_list = combined_filtered_edges_df['image_id'].unique().tolist()

@@ -29,7 +34,9 @@ class WhiteBoxTesting:
             image_id: list(zip(group['source'], group['target'], group['target_probability']))
             for image_id, group in edges_df[edges_df['image_id'].isin(unique_ids_list)].groupby('image_id')
         }
-
-
+
+        if self.verbose:
+            print("Matched dictionary:")
+            print(matched_dict)

         return matched_dict
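The practical effect is that the exploratory prints in `find_problematic_images` are now opt-in. A toy sketch of the call pattern (the DataFrame rows are invented for illustration, but the column names match the `edges_df` schema the method filters on):

```python
import pandas as pd
from BETTER_NMA.utilss.classes.whitebox_testing import WhiteBoxTesting

# Invented edge data; a real edges_df comes from NMA(..., save_connections=True).
edges_df = pd.DataFrame({
    "image_id": [0, 0, 1],
    "source": ["beetle", "bee", "tulip"],
    "target": ["bee", "beetle", "orchid"],
    "target_probability": [0.41, 0.38, 0.22],
})

wb = WhiteBoxTesting("demo_model", verbose=True)  # verbose=True restores the old debug prints
matches = wb.find_problematic_images(["beetle", "tulip"], ["bee", "orchid"], edges_df)
# matches: {image_id: [(source, target, target_probability), ...]}
```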
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/utilss/wordnet_utils.py
RENAMED
@@ -12,9 +12,32 @@ def folder_name_to_number(folder_name):
     folder_number = 'n{:08d}'.format(offset)
     return folder_number

+def synset_to_readable(label):
+    # Check if label is in synset format
+    if isinstance(label, str) and label.startswith('n') and label[1:].isdigit():
+        special_cases = {
+            "n02012849": "crane bird",     # Bird
+            "n03126707": "crane machine",  # Vehicle
+            "n03710637": "maillot",        # Swimsuit
+            "n03710721": "tank suit"       # Swimsuit
+        }
+
+        if label in special_cases:
+            return special_cases[label]
+
+        try:
+            offset = int(label[1:])
+            synset = wn.synset_from_pos_and_offset('n', offset)
+            return synset.lemma_names()[0].replace('_', ' ')
+        except Exception:
+            return label  # fallback if not found
+    else:
+        return label  # already readable
+
 def common_group(groups):
     common_hypernyms = []
     hierarchy = {}
+
     for group in groups:
         hierarchy[group] = []
         synsets = wn.synsets(group)
@@ -29,48 +52,118 @@ def common_group(groups):
     for hypernym in hierarchy[groups.pop()]:
         if all(hypernym in hypernyms for hypernyms in hierarchy.values()):
             common_hypernyms.append(hypernym)
+
     return common_hypernyms[::-1]

+def process_hierarchy(hierarchy_data,):
+    """Process the entire hierarchy, renaming clusters while preserving structure."""
+    return _rename_clusters(hierarchy_data)

 def get_all_leaf_names(node):
+    """Extract all leaf node names from a cluster hierarchy."""
     if "children" not in node:
+        # Only return actual object names, not cluster names
         if "cluster" not in node["name"]:
             return [node["name"]]
         return []
+
     names = []
     for child in node["children"]:
         names.extend(get_all_leaf_names(child))
     return names
+
+def _rename_clusters(tree):
+    """
+    Traverse the tree in BFS manner and rename clusters based on child names,
+    which can be leaves or already-renamed clusters.
+    """
+    used_names = set()
+    all_leaf_names = {leaf.lower() for leaf in get_all_leaf_names(tree)}

+    queue = deque()
+    queue.append(tree)

-
-
+    # BFS traversal, we store nodes with children in postprocess queue
+    postprocess_nodes = []

+    while queue:
+        node = queue.popleft()
+        if "children" in node:
+            queue.extend(node["children"])
+            postprocess_nodes.append(node)  # non-leaf clusters to process after children

-
+    # Process clusters in reverse BFS (bottom-up)
+    for node in reversed(postprocess_nodes):
+        if "cluster" not in node["name"]:
+            continue  # already renamed
+
+        # Collect child names (renamed or original leaves)
+        child_names = [child["name"] for child in node["children"] if "name" in child]
+
+        # Get hypernym candidate from child names
+        candidate = find_common_hypernyms(child_names)
+        if candidate:
+            # Ensure it’s unique
+            base = candidate
+            unique = base
+            idx = 1
+            while unique.lower() in all_leaf_names or unique.lower() in {n.lower() for n in used_names}:
+                idx += 1
+                unique = f"{base}_{idx}"
+            node["name"] = unique
+            used_names.add(unique)
+
+    return tree
+
+def _get_top_synsets(
+    phrase: str,
+    pos=wn.NOUN,
+    max_senses: int = 15
+) -> list[wn.synset]:
+    """
+    Return up to `max_senses` synsets for `phrase`.
+    - Replaces spaces/underscores so WordNet can match “pickup truck” or “aquarium_fish”.
+    - WordNet already orders synsets by frequency, so we take only the first few.
+    """
     lemma = phrase.strip().lower().replace(" ", "_")
     syns = wn.synsets(lemma, pos=pos)
     return syns[:max_senses] if syns else []


+# ---------------------------------------------------
+# Core: compute the single best hypernym for a set of words
+# ---------------------------------------------------
 def _find_best_common_hypernym(
     leaves: list[str],
     max_senses_per_word: int = 5,
-    banned_lemmas: set[str] = None
-) ->
-
+    banned_lemmas: set[str] = None,
+) -> str | None:
+    """
+    1. For each leaf in `leaves`, fetch up to `max_senses_per_word` synsets.
+    2. For EVERY pair of leaves (w1, w2), for EVERY combination of synset ∈ synsets(w1) × synsets(w2),
+       call syn1.lowest_common_hypernyms(syn2) → yields a list of shared hypernyms.
+       Tally them in `lch_counter`.
+    3. Sort the candidates by (frequency, min_depth) so we pick the most-specific, most-common ancestor.
+    4. Filter out overly generic lemmas (like “entity”, “object”) unless NOTHING else remains.
+    5. Return the best lemma_name (underscore → space, capitalized).
+    """
     if banned_lemmas is None:
         banned_lemmas = {"entity", "object", "physical_entity", "thing", "Object", "Whole", "Whole", "Physical_entity", "Thing", "Entity", "Artifact"}
-
+
+
+    # 1. Map each leaf → up to `max_senses_per_word` synsets
     word_to_synsets: dict[str, list[wn.synset]] = {}
     for w in leaves:
         syns = _get_top_synsets(w, wn.NOUN, max_senses_per_word)
         if syns:
             word_to_synsets[w] = syns

+    # If fewer than 2 words have ANY synsets, we cannot get a meaningful common hypernym
     if len(word_to_synsets) < 2:
         return None

+    # 2. For each pair of distinct leaves w1, w2, do ALL combinations of synset₁ × synset₂
+    #    and tally lowest_common_hypernyms
     lch_counter: Counter[wn.synset] = Counter()
     words_list = list(word_to_synsets.keys())

@@ -82,7 +175,6 @@ def _find_best_common_hypernym(
             try:
                 common = s1.lowest_common_hypernyms(s2)
             except Exception as e:
-                print(f"Error computing LCH({s1.name()}, {s2.name()}): {e}")
                 continue
             for hyp in common:
                 lch_counter[hyp] += 1
@@ -90,12 +182,14 @@ def _find_best_common_hypernym(
     if not lch_counter:
         return None

+    # 3. Sort candidates by (frequency, min_depth) descending
     candidates = sorted(
         lch_counter.items(),
         key=lambda item: (item[1], item[0].min_depth()),
         reverse=True
     )

+    # 4. Filter out generic lemma_names unless NOTHING else remains
     filtered: list[tuple[wn.synset, int]] = []
     for syn, freq in candidates:
         lemma = syn.name().split(".")[0].lower()
@@ -103,75 +197,61 @@ def _find_best_common_hypernym(
             continue
         filtered.append((syn, freq))

+    # If every candidate was filtered out, allow the first generic anyway
     if not filtered:
         filtered = candidates

     best_synset, best_freq = filtered[0]
     best_label = (best_synset.name().split(".")[0].replace(" ", "_")).lower()
+
     return best_label

+
+# ---------------------------------------------------
+# Public version: branching on single vs. multiple leaves
+# ---------------------------------------------------
 def find_common_hypernyms(
     words: list[str],
     abstraction_level: int = 0,
-) ->
-
+) -> str | None:
+    """
+    Improved drop-in replacement for your old `find_common_hypernyms`.
+    1. Normalize each word (underscores ↔ spaces, lowercase) and filter out anything containing "Cluster".
+    2. If there’s exactly one valid leaf, pick its first hypernym (one level up) unless it’s “entity”.
+    3. If there are 2+ leaves, call _find_best_common_hypernym on them.
+    """
+
     clean_leaves = [
+        # w.strip().lower().replace(" ", "_")
         re.sub(r'_\d+$', '', w.strip().lower().replace(" ", "_"))
         for w in words
         if w and "cluster" not in w.lower()
     ]

+    # If nothing remains, bail out
     if not clean_leaves:
         return None

+    # Single-word case: pick its immediate hypernym (second-to-bottom in the hypernym path)
     if len(clean_leaves) == 1:
         word = clean_leaves[0]
         synsets = _get_top_synsets(word, wn.NOUN, max_senses=10)
         if not synsets:
             return None

+        # Choose the first sense’s longest hypernym path, then take one level up from leaf sense.
         paths = synsets[0].hypernym_paths()  # list of lists
         if not paths:
             return None

         longest_path = max(paths, key=lambda p: len(p))
+        # If path has at least 2 nodes, candidate = one level above the leaf sense
         if len(longest_path) >= 2:
             candidate = longest_path[-2]
             name = (candidate.name().split(".")[0].replace(" ", "_")).lower()
             if name.lower() not in {word, "entity"}:
                 return name
         return None
-    return _find_best_common_hypernym(clean_leaves, max_senses_per_word=5)
-
-
-def _rename_clusters(tree):
-    used_names = set()
-    all_leaf_names = {leaf.lower() for leaf in get_all_leaf_names(tree)}
-    queue = deque()
-    queue.append(tree)
-    postprocess_nodes = []
-
-    while queue:
-        node = queue.popleft()
-        if "children" in node:
-            queue.extend(node["children"])
-            postprocess_nodes.append(node)
-
-    for node in reversed(postprocess_nodes):
-        if "cluster" not in node["name"]:
-            continue

-
-
-
-        if candidate:
-            base = candidate
-            unique = base
-            idx = 1
-            while unique.lower() in all_leaf_names or unique.lower() in {n.lower() for n in used_names}:
-                idx += 1
-                unique = f"{base}_{idx}"
-            node["name"] = unique
-            used_names.add(unique)
-
-    return tree
+    # 2+ leaves: use pairwise LCA approach
+    return _find_best_common_hypernym(clean_leaves, max_senses_per_word=5)
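Taken together, these helpers normalize synset IDs to readable names and label internal clusters with the most specific WordNet hypernym shared by their leaves. A standalone sketch of the underlying NLTK calls (assumes `nltk.download('wordnet')` has been run; exact outputs depend on the installed WordNet version):

```python
from nltk.corpus import wordnet as wn

# Offset → readable name, as synset_to_readable does for "n########" labels.
syn = wn.synset_from_pos_and_offset('n', 2012849)
print(syn.lemma_names()[0].replace('_', ' '))  # "crane" (the bird sense)

# Pairwise lowest-common-hypernym tallying, the core of _find_best_common_hypernym.
s1 = wn.synsets('apple', pos=wn.NOUN)[0]
s2 = wn.synsets('pear', pos=wn.NOUN)[0]
for hyp in s1.lowest_common_hypernyms(s2):
    # (frequency, min_depth) is the sort key used above; deeper = more specific.
    print(hyp.name(), hyp.min_depth())  # e.g. edible_fruit.n.01 ...
```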
{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA/white_box_testing.py
RENAMED
@@ -57,7 +57,7 @@ def get_white_box_analysis(edges_df_path, model_filename, dataset_str, source_la
     edges_data.load_dataframe()
     df = edges_data.get_dataframe()

-    whitebox_testing = WhiteBoxTesting(model_filename)
+    whitebox_testing = WhiteBoxTesting(model_filename, verbose=False)
     problematic_imgs_dict = whitebox_testing.find_problematic_images(
         source_labels, target_labels, df, dataset_str)

{better_nma-1.0.2 → better_nma-1.0.5}/BETTER_NMA.egg-info/SOURCES.txt
RENAMED
@@ -39,4 +39,5 @@ BETTER_NMA/utilss/classes/preprocessing/tree_node.py
 BETTER_NMA/utilss/classes/preprocessing/z_builder.py
 BETTER_NMA/utilss/enums/__init__.py
 BETTER_NMA/utilss/enums/explanation_method.py
-BETTER_NMA/utilss/enums/heap_types.py
+BETTER_NMA/utilss/enums/heap_types.py
+tests/test_main.py
{better_nma-1.0.2 → better_nma-1.0.5}/setup.py
RENAMED
@@ -8,7 +8,7 @@ except FileNotFoundError:

 setup(
     name="BETTER_NMA",
-    version="1.0.2",
+    version="1.0.5",
     author="BETTER_XAI",
     author_email="BETTERXAI2025@gmail.com",
     description="NMA: Dendrogram-based model analysis, white-box testing, and adversarial detection",
better_nma-1.0.5/tests/test_main.py
ADDED
@@ -0,0 +1,280 @@
+"""
+Test script for BETTER_NMA package main functionalities with CIFAR-100
+Testing: nma.plot, plot_sub_dendrogram, get_tree_as_dict, white_box_testing, find_lca
+"""
+
+import sys
+import os
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from BETTER_NMA import NMA
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras.applications.resnet50 import preprocess_input
+import json
+import matplotlib.pyplot as plt
+
+# Suppress TensorFlow warnings
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
+
+print("TF version:", tf.__version__)
+
+def test_cifar100_nma():
+    """Test all NMA functionalities with CIFAR-100"""
+
+    print("="*60)
+    print("Testing NMA with CIFAR-100 Dataset")
+    print("="*60)
+
+    # 1. Load CIFAR-100 dataset
+    print("\n1. Loading CIFAR-100 dataset...")
+    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data()
+
+    labels = [
+        'apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle',
+        'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle',
+        'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur',
+        'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard',
+        'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain',
+        'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree',
+        'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket',
+        'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider',
+        'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor',
+        'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree', 'wolf', 'woman', 'worm'
+    ]
+
+    # Use a small subset for faster testing (first 500 samples)
+    x_train = x_train[:500]
+    y_train = y_train[:500]
+
+    # Preprocess data
+    x_train = preprocess_input(x_train)
+    y_train = y_train.astype(int).flatten()
+    y_train_strings = [labels[i] for i in y_train]
+
+    print(f"x_train shape: {x_train.shape}")
+    print(f"y_train_strings example: {y_train_strings[:5]}")
+
+    # 2. Load or create model
+    print("\n2. Loading CIFAR-100 model...")
+
+    # Check if model exists
+    model_path = "tests/cifar100_resnet.keras"
+    if os.path.exists(model_path):
+        try:
+            cifar100_model = tf.keras.models.load_model(model_path)
+            print(f"Loaded model from: {model_path}")
+        except Exception as e:
+            print(f"Could not load saved model: {e}")
+            # Create simple model for testing
+            cifar100_model = tf.keras.Sequential([
+                tf.keras.layers.Input(shape=(32, 32, 3)),
+                tf.keras.layers.Conv2D(32, 3, activation='relu'),
+                tf.keras.layers.GlobalAveragePooling2D(),
+                tf.keras.layers.Dense(100, activation='softmax')
+            ])
+            cifar100_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
+            print("Created simple test model")
+    else:
+        print("Creating simple model for testing...")
+        cifar100_model = tf.keras.Sequential([
+            tf.keras.layers.Input(shape=(32, 32, 3)),
+            tf.keras.layers.Conv2D(32, 3, activation='relu'),
+            tf.keras.layers.GlobalAveragePooling2D(),
+            tf.keras.layers.Dense(100, activation='softmax')
+        ])
+        cifar100_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
+        print("Created simple test model")
+
+    # 3. Initialize NMA with similarity explanation method
+    print("\n3. Initializing NMA...")
+    try:
+        nma = NMA(
+            x_train=x_train,
+            y_train=y_train_strings,
+            labels=labels,
+            model=cifar100_model,
+            explanation_method="similarity",
+            top_k=4,
+            min_confidence=0.8,
+            batch_size=32,
+            save_connections=True  # Required for white-box testing
+        )
+        print("NMA initialized successfully")
+    except Exception as e:
+        print(f"Error initializing NMA: {e}")
+        return None
+
+    print("\n" + "="*60)
+    print("TESTING NMA FUNCTIONALITIES")
+    print("="*60)
+
+    # Test 1: nma.plot() - Full dendrogram
+    print("\n📊 Test 1: nma.plot() - Full dendrogram")
+    print("-"*40)
+    try:
+        nma.plot(title="CIFAR-100 Full Dendrogram", figsize=(20, 20))
+        print("✓ Full dendrogram plotted")
+        plt.close('all')  # Close plots to save memory
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    # Test 2: nma.plot() with sub_labels
+    print("\n📊 Test 2: nma.plot() with sub_labels")
+    print("-"*40)
+    try:
+        # Test with tree-related labels
+        tree_labels = ["maple_tree", "oak_tree", "palm_tree", "pine_tree", "willow_tree", "forest"]
+        nma.plot(sub_labels=tree_labels, title="Tree Classes Sub-Dendrogram", figsize=(12, 8))
+        print(f"✓ Sub-dendrogram plotted for: {tree_labels}")
+        plt.close('all')
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    # Test 3: nma.plot_sub_dendrogram()
+    print("\n📊 Test 3: nma.plot_sub_dendrogram()")
+    print("-"*40)
+    try:
+        # Test with people-related labels
+        people_labels = ["baby", "boy", "girl", "man", "woman"]
+        nma.plot_sub_dendrogram(sub_labels=people_labels, title="People Classes", figsize=(10, 6))
+        print(f"✓ plot_sub_dendrogram worked for: {people_labels}")
+        plt.close('all')
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    # Test 4: nma.get_tree_as_dict()
+    print("\n📋 Test 4: nma.get_tree_as_dict()")
+    print("-"*40)
+    try:
+        # Full tree
+        tree_dict = nma.get_tree_as_dict()
+        print("✓ Got full tree as dictionary")
+        print(f"  Keys: {list(tree_dict.keys())}")
+        if 'name' in tree_dict:
+            print(f"  Root name: {tree_dict['name']}")
+
+        # Sub-tree with animal labels
+        animal_labels = ["bear", "beaver", "bee", "beetle", "butterfly"]
+        sub_tree_dict = nma.get_tree_as_dict(sub_labels=animal_labels)
+        print(f"✓ Got sub-tree for: {animal_labels}")
+
+        # Show structure
+        tree_json = json.dumps(sub_tree_dict, indent=2)
+        print(f"  Sub-tree preview (first 200 chars): {tree_json[:200]}...")
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    # Test 5: nma.find_lca()
+    print("\n🔍 Test 5: nma.find_lca() - Finding Lowest Common Ancestors")
+    print("-"*40)
+
+    test_pairs = [
+        ("woman", "girl"),           # People/female cluster
+        ("man", "boy"),              # People/male cluster
+        ("maple_tree", "oak_tree"),  # Tree cluster
+        ("bee", "beetle"),           # Insect cluster
+        ("apple", "pear"),           # Fruit cluster
+        ("tulip", "orchid"),         # Flower cluster
+    ]
+
+    for label1, label2 in test_pairs:
+        try:
+            lca = nma.find_lca(label1, label2)
+            print(f"✓ LCA of '{label1}' and '{label2}': {lca}")
+        except Exception as e:
+            print(f"✗ Error finding LCA for {label1}-{label2}: {e}")
+
+    # Test 6: nma.white_box_testing()
+    print("\n🧪 Test 6: nma.white_box_testing()")
+    print("-"*40)
+    try:
+        # Test as in Kaggle example
+        source_labels = ["beetle", "tulip"]
+        target_labels = ["bee", "orchid"]
+
+        print(f"  Testing: {source_labels} → {target_labels}")
+
+        # Without analysis
+        problematic_imgs = nma.white_box_testing(
+            source_labels=source_labels,
+            target_labels=target_labels,
+            analyze_results=False
+        )
+
+        print(f"✓ White-box testing completed")
+        print(f"  Found {len(problematic_imgs)} problematic images")
+
+        if problematic_imgs:
+            # Show first problematic image
+            img_id = list(problematic_imgs.keys())[0]
+            matches = problematic_imgs[img_id]
+            print(f"  Example - Image {img_id}: {len(matches)} matches")
+            for match in matches[:3]:
+                print(f"    {match[0]} → {match[1]}: {match[2]:.4f}")
+
+        # With analysis
+        analyzed_results = nma.white_box_testing(
+            source_labels=source_labels,
+            target_labels=target_labels,
+            analyze_results=True,
+            x_train=x_train,
+            encode_images=False
+        )
+        print(f"✓ Analysis completed: {len(analyzed_results)} results")
+
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    # Test 7: nma.get_white_box_analysis()
+    print("\n🧪 Test 7: nma.get_white_box_analysis()")
+    print("-"*40)
+    try:
+        source_labels = ["woman", "girl"]
+        target_labels = ["man", "boy"]
+
+        print(f"  Testing: {source_labels} → {target_labels}")
+
+        analysis = nma.get_white_box_analysis(
+            source_labels=source_labels,
+            target_labels=target_labels,
+            x_train=x_train
+        )
+
+        print(f"✓ Analysis completed: {len(analysis)} entries")
+        if analysis:
+            print(f"  Entry keys: {list(analysis[0].keys())}")
+
+    except Exception as e:
+        print(f"✗ Error: {e}")
+
+    print("\n" + "="*60)
+    print("TEST SUMMARY")
+    print("="*60)
+
+    print("\n✅ Tested NMA functionalities:")
+    print("  1. nma.plot() - Full and sub dendrograms")
+    print("  2. nma.plot_sub_dendrogram() - Specific label subsets")
+    print("  3. nma.get_tree_as_dict() - Tree structure as dictionary")
+    print("  4. nma.find_lca() - Finding lowest common ancestors")
+    print("  5. nma.white_box_testing() - Identifying problematic images")
+    print("  6. nma.get_white_box_analysis() - Detailed analysis")
+
+    return nma
+
+if __name__ == "__main__":
+    try:
+        print("Starting CIFAR-100 NMA tests...")
+        print("Using subset of 500 samples for faster testing\n")
+
+        nma = test_cifar100_nma()
+
+        if nma:
+            print("\n✅ All tests completed successfully!")
+        else:
+            print("\n⚠ Tests completed with issues")
+
+    except Exception as e:
+        print(f"\n❌ Fatal error: {e}")
+        import traceback
+        traceback.print_exc()