BETTER-NMA 1.0.1__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BETTER_NMA/explaination_score.py +52 -0
- BETTER_NMA/main.py +5 -0
- BETTER_NMA/utilss/classes/adversarial_dataset.py +4 -9
- BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +1 -1
- BETTER_NMA/utilss/classes/score_calculator.py +0 -1
- BETTER_NMA/utilss/photos_uitls.py +3 -3
- BETTER_NMA/utilss/photos_utils.py +3 -3
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.2.dist-info}/METADATA +1 -1
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.2.dist-info}/RECORD +11 -10
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.2.dist-info}/WHEEL +0 -0
- {BETTER_NMA-1.0.1.dist-info → BETTER_NMA-1.0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,52 @@
|
|
1
|
+
from .utilss.classes.score_calculator import ScoreCalculator
|
2
|
+
from itertools import combinations
|
3
|
+
|
4
|
+
def get_explaination_score(dendrogram, class_names, normalize=True):
    """
    Get the score of the entire dendrogram based on pairwise LCA ancestor counts.

    For every unordered pair of classes, the number of ancestors up to their
    lowest common ancestor (as reported by
    ``ScoreCalculator.count_ancestors_to_lca``) is summed. Higher ancestor
    counts are treated as lower explanation quality.

    Parameters:
    - dendrogram: The hierarchical clustering dendrogram; only its ``Z``
      attribute is used (passed to ``ScoreCalculator``).
    - class_names: List of class names corresponding to model outputs. Pairs
      are formed by position, so duplicate names map to their real indices.
    - normalize: If True (default), return a percentage in [0, 100] where
      higher is better. If False, return the raw summed ancestor count.

    Returns:
    - score: The normalized explanation score (float) when ``normalize`` is
      True, otherwise the raw total ancestor count over all processed pairs.
    """
    import logging

    logger = logging.getLogger(__name__)
    score_calculator = ScoreCalculator(dendrogram.Z, class_names)

    num_classes = len(class_names)
    total_count = 0
    processed_pairs = 0

    # Iterate over index pairs directly: no O(n) list.index() per pair, and
    # duplicate class names resolve to their actual positions.
    for idx1, idx2 in combinations(range(num_classes), 2):
        try:
            count, _ = score_calculator.count_ancestors_to_lca(idx1, idx2)
        except ValueError as e:
            logger.warning(
                "Error processing pair (%s, %s): %s",
                class_names[idx1], class_names[idx2], e,
            )
            continue
        total_count += count
        processed_pairs += 1

    logger.debug("Total combinations processed: %d", processed_pairs)
    logger.debug("Total ancestor count sum: %d", total_count)

    if not normalize:
        return total_count

    # Maximum ancestors per pair is bounded by the tree height.
    max_ancestors_per_pair = num_classes - 1
    # Only pairs that were actually scored count toward the maximum; skipped
    # pairs would otherwise act as "zero ancestors" and inflate the score.
    theoretical_max = processed_pairs * max_ancestors_per_pair

    if theoretical_max == 0:
        normalized_score = 0.0
    else:
        # Invert so higher ancestor counts = lower explanation quality,
        # then scale to 0-100%.
        normalized_score = max(0.0, (1 - total_count / theoretical_max) * 100)

    logger.debug("Theoretical maximum: %d", theoretical_max)
    logger.debug("Normalized score: %.2f%%", normalized_score)

    return normalized_score
|
BETTER_NMA/main.py
CHANGED
@@ -8,6 +8,7 @@ from .detect_attack import detect_adversarial_image
|
|
8
8
|
from .query_image import query_image
|
9
9
|
from .utilss.verbal_explanation import get_verbal_explanation
|
10
10
|
from .white_box_testing import analyze_white_box_results, get_white_box_analysis
|
11
|
+
from .explaination_score import get_explaination_score
|
11
12
|
from .adversarial_score import get_adversarial_score
|
12
13
|
from .find_lca import get_lca
|
13
14
|
|
@@ -212,6 +213,10 @@ class NMA:
|
|
212
213
|
score = get_adversarial_score(image, self.model, self.dendrogram_object.Z, self.labels, top_k=top_k)
|
213
214
|
return score
|
214
215
|
|
216
|
+
def explanation_score(self, normalize=True):
    """Return the explanation score of this model's dendrogram.

    Delegates to ``get_explaination_score`` with the stored dendrogram object
    and labels; ``normalize`` is forwarded unchanged.
    """
    dendrogram, labels = self.dendrogram_object, self.labels
    return get_explaination_score(dendrogram, labels, normalize=normalize)
|
218
|
+
|
219
|
+
|
215
220
|
## query and explanation functions: ##
|
216
221
|
|
217
222
|
def query_image(self, image, top_k=5):
|
@@ -13,32 +13,27 @@ class AdversarialDataset:
|
|
13
13
|
scores = []
|
14
14
|
labels = []
|
15
15
|
|
16
|
-
print("getting preprocess function...")
|
17
|
-
|
18
16
|
try:
|
19
17
|
for image in self.clear_images[:50]:
|
20
18
|
# Add batch dimension for model prediction
|
21
19
|
image_batch = np.expand_dims(image, axis=0)
|
22
|
-
score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch))
|
20
|
+
score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch, verbose=0))
|
23
21
|
scores.append(score)
|
24
22
|
labels.append(0)
|
25
23
|
except Exception as e:
|
26
|
-
print(f"Error processing clean
|
24
|
+
print(f"Error processing clean images: {e}")
|
27
25
|
|
28
26
|
# Generate features for PGD attacks
|
29
|
-
print("Generating attack features...")
|
30
27
|
try:
|
31
28
|
for adv_image in self.adversarial_images[:50]:
|
32
29
|
# Add batch dimension for model prediction
|
33
30
|
adv_image_batch = np.expand_dims(adv_image, axis=0)
|
34
|
-
score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch))
|
31
|
+
score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch, verbose=0))
|
35
32
|
scores.append(score)
|
36
33
|
labels.append(1)
|
37
34
|
except Exception as e:
|
38
|
-
print(f"Error processing
|
35
|
+
print(f"Error processing attacked images: {e}")
|
39
36
|
|
40
|
-
print("labels:", labels)
|
41
|
-
print("scores:", scores)
|
42
37
|
|
43
38
|
# Convert to numpy arrays
|
44
39
|
X = np.array(scores)
|
@@ -11,7 +11,7 @@ class BatchPredictor:
|
|
11
11
|
self.buffer_results = [] # To store batch results
|
12
12
|
|
13
13
|
def get_top_predictions(self, X, labels, top_k, graph_threshold):
|
14
|
-
batch_preds = self.model.predict(np.array(X))
|
14
|
+
batch_preds = self.model.predict(np.array(X), verbose=0)
|
15
15
|
batch_results = []
|
16
16
|
for pred in batch_preds:
|
17
17
|
top_indices = pred.argsort()[-top_k:][::-1]
|
@@ -134,7 +134,6 @@ class ScoreCalculator:
|
|
134
134
|
|
135
135
|
# Calculate semantic distance
|
136
136
|
rank_count, _ = self.count_ancestors_to_lca(idx1, idx2)
|
137
|
-
print(f"Rank Count for {label1} and {label2}: {rank_count}")
|
138
137
|
|
139
138
|
# Calculate product of probabilities and distance
|
140
139
|
prob_product = prob1 * prob2
|
@@ -24,10 +24,10 @@ def get_preprocess_function(model):
|
|
24
24
|
model_config = model.get_config()
|
25
25
|
if "name" in model_config:
|
26
26
|
model_name = model_config["name"].lower()
|
27
|
-
print(f"Model name: {model_name}")
|
27
|
+
# print(f"Model name: {model_name}")
|
28
28
|
for key in preprocess_map.keys():
|
29
29
|
if key in model_name:
|
30
|
-
print(f"Detected model type: {key}")
|
30
|
+
# print(f"Detected model type: {key}")
|
31
31
|
return preprocess_map[key]
|
32
32
|
|
33
33
|
for layer in model.layers:
|
@@ -35,7 +35,7 @@ def get_preprocess_function(model):
|
|
35
35
|
print(f"Checking layer: {layer_name}")
|
36
36
|
for model_name in preprocess_map.keys():
|
37
37
|
if model_name in layer_name:
|
38
|
-
print(f"Detected model type: {model_name}")
|
38
|
+
# print(f"Detected model type: {model_name}")
|
39
39
|
return preprocess_map[model_name]
|
40
40
|
|
41
41
|
print("No supported model type found in the configuration. Falling back to generic normalization.")
|
@@ -25,15 +25,15 @@ def get_preprocess_function(model):
|
|
25
25
|
model_config = model.get_config()
|
26
26
|
if "name" in model_config:
|
27
27
|
model_name = model_config["name"].lower()
|
28
|
-
print(f"Model name: {model_name}")
|
28
|
+
# print(f"Model name: {model_name}")
|
29
29
|
for key in preprocess_map.keys():
|
30
30
|
if key in model_name:
|
31
|
-
print(f"Detected model type: {key}")
|
31
|
+
# print(f"Detected model type: {key}")
|
32
32
|
return preprocess_map[key]
|
33
33
|
|
34
34
|
for layer in model.layers:
|
35
35
|
layer_name = layer.name.lower()
|
36
|
-
print(f"Checking layer: {layer_name}")
|
36
|
+
# print(f"Checking layer: {layer_name}")
|
37
37
|
for model_name in preprocess_map.keys():
|
38
38
|
if model_name in layer_name:
|
39
39
|
print(f"Detected model type: {model_name}")
|
@@ -2,8 +2,9 @@ BETTER_NMA/__init__.py,sha256=ePaQnto0n4hccz2490Z7bxwcbtONVAa6nWqg7SL4W1Y,428
|
|
2
2
|
BETTER_NMA/adversarial_score.py,sha256=qgScTqS-aJ2q4kFom505hBtonVzKK67fGS09J1_-G3o,875
|
3
3
|
BETTER_NMA/change_cluster_name.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
4
|
BETTER_NMA/detect_attack.py,sha256=s7YwTVMJFABSMt2aISR-zaIUxFaSWm9oODc9yF12KPY,4327
|
5
|
+
BETTER_NMA/explaination_score.py,sha256=GWtncYjIj8UF94g1UxtNRnyK7Cvh3PqVcShle_vLhUM,2003
|
5
6
|
BETTER_NMA/find_lca.py,sha256=UlyftOJmbSPuXzxvcheRb_IrdCqBsaSQHLchIRZIR-0,812
|
6
|
-
BETTER_NMA/main.py,sha256=
|
7
|
+
BETTER_NMA/main.py,sha256=2d3O8Nc6J9YTWXIsqIGIIWDkL1V3pWN1hhcDde873Nc,11525
|
7
8
|
BETTER_NMA/nma_creator.py,sha256=M-LlZGRkxhGYLHpaXTNoZj9AUH7uvev7hq7tbILWMLI,5137
|
8
9
|
BETTER_NMA/plot.py,sha256=nj2ca-ybzGMlo6HhCngyjGUNaJDmfsPxF5ad9xpxzvE,4383
|
9
10
|
BETTER_NMA/query_image.py,sha256=13AQ9-8QdzaIwH5-ELX3z3iJBP8nTDe-SMtwQve-1ek,906
|
@@ -11,19 +12,19 @@ BETTER_NMA/train_adversarial_detector.py,sha256=nMaQ-Pm2vP84qNR1GoKQiVPpmMC3rdor
|
|
11
12
|
BETTER_NMA/white_box_testing.py,sha256=zfhK8G-2cJH1AMevPywVnc05IhSqttf3YxQ6abdpM78,3524
|
12
13
|
BETTER_NMA/utilss/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
13
14
|
BETTER_NMA/utilss/models_utils.py,sha256=gBXY2LFH4iR-2GZmHeUnnB5n9t3VdjIc9sugHDrD3AM,671
|
14
|
-
BETTER_NMA/utilss/photos_uitls.py,sha256=
|
15
|
-
BETTER_NMA/utilss/photos_utils.py,sha256=
|
15
|
+
BETTER_NMA/utilss/photos_uitls.py,sha256=wxmIIKFgAKYkcYaK95UMjtY-LZS6NDVveKrHBQV8Q70,3166
|
16
|
+
BETTER_NMA/utilss/photos_utils.py,sha256=4EjDHbMjrJ8P9y-X4H05P4wez4uKNit60UGnu3sKsys,4412
|
16
17
|
BETTER_NMA/utilss/verbal_explanation.py,sha256=_hrYZUjBUYOfuGr7t5r-DACooR5d60dRtGfUj7FbeZw,549
|
17
18
|
BETTER_NMA/utilss/wordnet_utils.py,sha256=_A_gbdR7tf5tiyN2Oe5sB4vvBkWnr4KUl5e9iq5ft8c,5535
|
18
19
|
BETTER_NMA/utilss/classes/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
19
|
-
BETTER_NMA/utilss/classes/adversarial_dataset.py,sha256=
|
20
|
+
BETTER_NMA/utilss/classes/adversarial_dataset.py,sha256=LKmyseQetyBdoKrl7Q4MaTc7EzMZcMgqNdtCsdAvKHA,2299
|
20
21
|
BETTER_NMA/utilss/classes/adversarial_detector.py,sha256=BE_SxNEwcvuHERBiefefOmk1k6NJSo6juehkAjkEHuQ,2331
|
21
22
|
BETTER_NMA/utilss/classes/dendrogram.py,sha256=vtKBFfwzcz8k01Goc83pZlWC2pO86endTJURlkUWVQI,5141
|
22
23
|
BETTER_NMA/utilss/classes/edges_dataframe.py,sha256=q-RQ6beOeZeIgdEzwi8T5Ag2NBFySv7-ITD5m989nl4,1896
|
23
|
-
BETTER_NMA/utilss/classes/score_calculator.py,sha256=
|
24
|
+
BETTER_NMA/utilss/classes/score_calculator.py,sha256=zgZaFgFJeok2RzXFm9OE5pOuXY3euIXEiLGI44q30JM,5927
|
24
25
|
BETTER_NMA/utilss/classes/whitebox_testing.py,sha256=4WSEjQ5gl6f8xzWADAagZ3WtMHE889rW-zcYld9REnw,1367
|
25
26
|
BETTER_NMA/utilss/classes/preprocessing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
26
|
-
BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py,sha256=
|
27
|
+
BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py,sha256=Cw5yzxLCPYGCTvOWclyCLT4qNgCCTSbyr-BNHapkwn0,957
|
27
28
|
BETTER_NMA/utilss/classes/preprocessing/graph_builder.py,sha256=ILumiBY9BUIOxrIvq8C-8n945pK-t94Et6gZwJB-364,1672
|
28
29
|
BETTER_NMA/utilss/classes/preprocessing/heap_processor.py,sha256=KblmkVWVfMYtpZa4Wy1Ry0lVfdSr6h8LySt4S-lvIGo,1064
|
29
30
|
BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py,sha256=YAIElJS_fSffIb3D2N1OZu9U6z7RYrHQTfB6bH4-VPI,4027
|
@@ -32,7 +33,7 @@ BETTER_NMA/utilss/classes/preprocessing/z_builder.py,sha256=T8ETfL7mMOgEj7oYNsw6
|
|
32
33
|
BETTER_NMA/utilss/enums/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
33
34
|
BETTER_NMA/utilss/enums/explanation_method.py,sha256=Ang-rjvxO4AJ1IH4mwS8sNpSwt9jn3PlqFbPPT-R9I8,150
|
34
35
|
BETTER_NMA/utilss/enums/heap_types.py,sha256=0z1d2qu1ZCbpWRXKD1dTopn3M4G1CxRQW9HWxVxyPIA,88
|
35
|
-
BETTER_NMA-1.0.
|
36
|
-
BETTER_NMA-1.0.
|
37
|
-
BETTER_NMA-1.0.
|
38
|
-
BETTER_NMA-1.0.
|
36
|
+
BETTER_NMA-1.0.2.dist-info/METADATA,sha256=9_3kUtQ3Cw2TL_3wq4l0ICWgJtNncp197lFQUrsCfas,5100
|
37
|
+
BETTER_NMA-1.0.2.dist-info/WHEEL,sha256=PZUExdf71Ui_so67QXpySuHtCi3-J3wvF4ORK6k_S8U,91
|
38
|
+
BETTER_NMA-1.0.2.dist-info/top_level.txt,sha256=SVRNqWPvCnynWVyXNAYnf9CSQIvMAvE6iyyiGHodQgY,11
|
39
|
+
BETTER_NMA-1.0.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|