BETTER_NMA-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- BETTER_NMA/__init__.py +15 -0
- BETTER_NMA/adversarial_score.py +19 -0
- BETTER_NMA/change_cluster_name.py +0 -0
- BETTER_NMA/detect_attack.py +108 -0
- BETTER_NMA/find_lca.py +21 -0
- BETTER_NMA/main.py +285 -0
- BETTER_NMA/nma_creator.py +108 -0
- BETTER_NMA/plot.py +131 -0
- BETTER_NMA/query_image.py +22 -0
- BETTER_NMA/train_adversarial_detector.py +21 -0
- BETTER_NMA/utilss/__init__.py +0 -0
- BETTER_NMA/utilss/classes/__init__.py +0 -0
- BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
- BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
- BETTER_NMA/utilss/classes/dendrogram.py +131 -0
- BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
- BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
- BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
- BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
- BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
- BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
- BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
- BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
- BETTER_NMA/utilss/classes/score_calculator.py +165 -0
- BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
- BETTER_NMA/utilss/enums/__init__.py +0 -0
- BETTER_NMA/utilss/enums/explanation_method.py +6 -0
- BETTER_NMA/utilss/enums/heap_types.py +5 -0
- BETTER_NMA/utilss/models_utils.py +18 -0
- BETTER_NMA/utilss/photos_uitls.py +72 -0
- BETTER_NMA/utilss/photos_utils.py +104 -0
- BETTER_NMA/utilss/verbal_explanation.py +15 -0
- BETTER_NMA/utilss/wordnet_utils.py +177 -0
- BETTER_NMA/white_box_testing.py +101 -0
- BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
- BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
- BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
- BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/query_image.py
@@ -0,0 +1,22 @@
+from .utilss.verbal_explanation import get_verbal_explanation
+from .utilss.photos_utils import preprocess_loaded_image
+from .utilss.models_utils import get_top_k_predictions
+
+def query_image(image, model, labels, dendrogram_object, top_k=5):
+    # predict image
+    try:
+        preprocessed_image, pil_image = preprocess_loaded_image(model, image)
+        predictions = get_top_k_predictions(
+            model, preprocessed_image, labels)
+        top_label = predictions[0][0]  # Top label
+        top_k_predictions = predictions[:top_k]
+
+        consistency = dendrogram_object.find_name_hierarchy(
+            dendrogram_object.Z_tree_format, top_label)
+
+        explanation = get_verbal_explanation(consistency)
+
+        return top_k_predictions, explanation
+    except Exception as e:
+        print("Error occurred while querying image:", e)
+        return None
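For orientation (not part of the wheel): `query_image` returns a `(top_k_predictions, explanation)` pair on success and `None` on failure, so callers should guard for both. A minimal sketch, assuming the caller already has a trained `model`, a `labels` list, a built `dendrogram_object`, and an input `image` (all placeholders):

```python
# Hypothetical call site; image, model, labels and dendrogram_object
# are placeholders the caller must supply.
result = query_image(image, model, labels, dendrogram_object, top_k=5)
if result is not None:  # query_image returns None when prediction fails
    top_k_predictions, explanation = result
    print(top_k_predictions)
    print(explanation)
```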
BETTER_NMA/train_adversarial_detector.py
@@ -0,0 +1,21 @@
+from .utilss.classes.adversarial_detector import AdversarialDetector
+from .utilss.classes.adversarial_dataset import AdversarialDataset
+
+def _create_adversarial_dataset(Z_matrix, clean_images, adversarial_images, model, labels) -> dict:
+    adversarial_dataset = AdversarialDataset(model, clean_images, adversarial_images, Z_matrix, labels)
+    X_train, y_train, X_test, y_test = adversarial_dataset.create_logistic_regression_dataset()
+    result = {
+        "X_train": X_train,
+        "y_train": y_train,
+        "X_test": X_test,
+        "y_test": y_test
+    }
+    return result
+
+def create_logistic_regression_detector(Z_matrix, model, clean_images, adversarial_images, labels):
+    adversarial_dataset = _create_adversarial_dataset(Z_matrix, clean_images, adversarial_images, model, labels)
+    adversarial_detector = AdversarialDetector(adversarial_dataset)
+
+    print("Adversarial detector trained successfully!")
+
+    return adversarial_detector
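A hedged usage sketch (again not part of the package): `Z_matrix` would be the linkage matrix produced by the package's NMA pipeline, and the two image arrays are placeholders shaped like the model's input.

```python
# Hypothetical call; all inputs are placeholders.
detector = create_logistic_regression_detector(
    Z_matrix, model, clean_images, adversarial_images, labels
)
# The returned AdversarialDetector exposes predict / predict_proba
# (see adversarial_detector.py below).
flags = detector.predict(scores)  # scores: 2-D score array (placeholder)
```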
BETTER_NMA/utilss/__init__.py
File without changes
BETTER_NMA/utilss/classes/__init__.py
File without changes
BETTER_NMA/utilss/classes/adversarial_dataset.py
@@ -0,0 +1,61 @@
+from sklearn.model_selection import train_test_split
+import numpy as np
+from .score_calculator import ScoreCalculator
+
+class AdversarialDataset:
+    def __init__(self, model, clear_images, adversarial_images, Z_full, labels):
+        self.model = model
+        self.clear_images = clear_images
+        self.adversarial_images = adversarial_images
+        self.score_calculator = ScoreCalculator(Z_full=Z_full, class_names=labels)
+
+    def create_logistic_regression_dataset(self):
+        scores = []
+        labels = []
+
+        print("getting preprocess function...")
+
+        try:
+            for image in self.clear_images[:50]:
+                # Add batch dimension for model prediction
+                image_batch = np.expand_dims(image, axis=0)
+                score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch))
+                scores.append(score)
+                labels.append(0)
+        except Exception as e:
+            print(f"Error processing clean image: {e}")
+
+        # Generate features for PGD attacks
+        print("Generating attack features...")
+        try:
+            for adv_image in self.adversarial_images[:50]:
+                # Add batch dimension for model prediction
+                adv_image_batch = np.expand_dims(adv_image, axis=0)
+                score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch))
+                scores.append(score)
+                labels.append(1)
+        except Exception as e:
+            print(f"Error processing PGD attack on image: {e}")
+
+        print("labels:", labels)
+        print("scores:", scores)
+
+        # Convert to numpy arrays
+        X = np.array(scores)
+        y = np.array(labels)
+
+        # Reshape X to ensure it is 2D
+        if len(X.shape) == 1:
+            X = X.reshape(-1, 1)
+
+        # Split into training and test sets
+        X_train, X_test, y_train, y_test = train_test_split(
+            X, y, test_size=0.3, random_state=42
+        )
+
+        print(f"Training data shape: {X_train.shape}")
+        print(f"Clean samples: {sum(y_train == 0)}, Adversarial samples: {sum(y_train == 1)}")
+        print(f"Test data shape: {X_test.shape}")
+        print(f"Clean samples: {sum(y_test == 0)}, Adversarial samples: {sum(y_test == 1)}")
+
+        return X_train, y_train, X_test, y_test
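The load-bearing details above are the 1-D-to-2-D reshape (scikit-learn expects a feature matrix) and the 70/30 split. A self-contained sketch of the same pattern on synthetic scores, independent of the package:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic 1-D scores: first 50 "clean", last 50 "adversarial" (illustrative only).
scores = np.concatenate([np.random.normal(0.2, 0.1, 50),
                         np.random.normal(0.8, 0.1, 50)])
labels = np.array([0] * 50 + [1] * 50)

X = scores.reshape(-1, 1)  # sklearn expects a 2-D feature matrix
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.3, random_state=42
)
print(X_train.shape, X_test.shape)  # (70, 1) (30, 1)
```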
BETTER_NMA/utilss/classes/adversarial_detector.py
@@ -0,0 +1,63 @@
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import roc_curve
+import numpy as np
+
+class AdversarialDetector:
+    def __init__(self, dataset):
+        self.detector, self.threshold = self.train_adversarial_detector(dataset)
+
+
+    def predict(self, X):
+        # Predict probabilities
+        if self.detector is None:
+            raise ValueError("Detector model is not trained or loaded.")
+        if X is None or len(X) == 0:
+            raise ValueError("Input data is empty or None.")
+
+        y_pred_proba = self.detector.predict_proba(X)[:, 1]
+        # Apply the custom threshold
+        return (y_pred_proba >= self.threshold).astype(int)
+
+    def predict_proba(self, X):
+        if self.detector is None:
+            raise ValueError("Detector model is not trained or loaded.")
+        if X is None or len(X) == 0:
+            raise ValueError("Input data is empty or None.")
+        # Return predicted probabilities
+        return self.detector.predict_proba(X)
+
+    def train_adversarial_detector(self, dataset):
+        """
+        Train a logistic regression model to detect adversarial examples across different attack types.
+
+        Parameters:
+        - dataset: dict with 'X_train', 'y_train', 'X_test' and 'y_test'
+          arrays, as produced by
+          AdversarialDataset.create_logistic_regression_dataset()
+
+        Returns:
+        - The fitted LogisticRegression detector and the optimal decision
+          threshold (the ROC operating point maximizing tpr - fpr)
+        """
+
+        X_train, y_train = dataset['X_train'], dataset['y_train']
+
+        # Train logistic regression model
+        print("Training adversarial detector...")
+        detector = LogisticRegression(max_iter=1000, class_weight='balanced')
+        detector.fit(X_train, y_train)
+
+        X_test, y_test = dataset['X_test'], dataset['y_test']
+
+        # Predict probabilities for the test set
+        y_pred_proba = detector.predict_proba(X_test)[:, 1]
+
+        # Compute ROC curve and AUC
+        fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
+
+        # Find the optimal threshold (closest to top-left corner)
+        optimal_idx = np.argmax(tpr - fpr)
+        optimal_threshold = thresholds[optimal_idx]
+
+        return detector, optimal_threshold
+
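The threshold rule `np.argmax(tpr - fpr)` picks the ROC point maximizing Youden's J statistic. A self-contained illustration on synthetic one-feature data (not package code):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve

rng = np.random.default_rng(0)
# Synthetic scores: class 0 clusters near 0.2, class 1 near 0.8.
X = np.concatenate([rng.normal(0.2, 0.1, 100),
                    rng.normal(0.8, 0.1, 100)]).reshape(-1, 1)
y = np.array([0] * 100 + [1] * 100)

clf = LogisticRegression(max_iter=1000, class_weight="balanced").fit(X, y)
proba = clf.predict_proba(X)[:, 1]
fpr, tpr, thresholds = roc_curve(y, proba)
best = thresholds[np.argmax(tpr - fpr)]  # Youden's J: maximize tpr - fpr
print(f"optimal threshold: {best:.3f}")
preds = (proba >= best).astype(int)    # same rule AdversarialDetector.predict applies
```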
BETTER_NMA/utilss/classes/dendrogram.py
@@ -0,0 +1,131 @@
+from scipy.cluster.hierarchy import to_tree
+from ..wordnet_utils import process_hierarchy
+import json
+import os
+import pickle
+
+class Dendrogram:
+    def __init__(self, Z=None):
+        self.Z = Z
+        self.Z_tree_format = None
+        self.dendrogram_filename = None
+
+    def _build_tree_format(self, node, labels):
+        if node.is_leaf():
+            return {
+                "id": node.id,
+                "name": labels[node.id],
+            }
+        else:
+            return {
+                "id": node.id,
+                "name": f"cluster_{node.id}",
+                "children": [self._build_tree_format(node.get_left(), labels), self._build_tree_format(node.get_right(), labels)],
+                "value": node.dist
+            }
+
+    def build_tree_hierarchy(self, linkage_matrix, labels):
+        tree, nodes = to_tree(linkage_matrix, rd=True)
+        self.Z_tree_format = self._build_tree_format(tree, labels)
+        self.Z_tree_format = process_hierarchy(self.Z_tree_format)
+        return self.Z_tree_format
+
+    def filter_dendrogram_by_labels(self, full_data, target_labels):
+        def contains_target_label(node):
+            if 'children' not in node:
+                return node.get('name') in target_labels
+            for child in node.get('children', []):
+                if contains_target_label(child):
+                    return True
+            return False
+        def filter_tree(node):
+            if not contains_target_label(node):
+                return None
+            new_node = {
+                'id': node.get('id'),
+                'name': node.get('name')
+            }
+            if 'value' in node:
+                new_node['value'] = node.get('value')
+            if 'children' not in node:
+                return new_node
+            filtered_children = []
+            for child in node.get('children', []):
+                filtered_child = filter_tree(child)
+                if filtered_child:
+                    filtered_children.append(filtered_child)
+            if filtered_children:
+                new_node['children'] = filtered_children
+            return new_node
+        return filter_tree(full_data)
+
+    def merge_clusters(self, node):
+        if node is None:
+            return None
+        if "children" not in node:
+            return node
+        merged_children = []
+        for child in node["children"]:
+            merged_child = self.merge_clusters(child)
+            if merged_child:
+                merged_children.append(merged_child)
+        if all(c.get("value", 0) == 100 for c in merged_children):
+            node["children"] = [grandchild for child in merged_children for grandchild in child.get("children", [])]
+        else:
+            node["children"] = merged_children
+        if len(node["children"]) == 1:
+            return node["children"][0]
+        return node
+
+    def get_sub_dendrogram_formatted(self, selected_labels):
+        filtered_tree = self.filter_dendrogram_by_labels(self.Z_tree_format, selected_labels)
+        if filtered_tree is None:
+            raise ValueError(f"No clusters found for the selected labels: {selected_labels}")
+        filtered_tree = self.merge_clusters(filtered_tree)
+        filtered_tree_json = json.dumps(filtered_tree, indent=2)
+        return filtered_tree_json
+
+    def find_name_hierarchy(self, node, target_name):
+        if node.get('name') == target_name:
+            return [target_name]
+        if 'children' in node:
+            for child in node['children']:
+                result = self.find_name_hierarchy(child, target_name)
+                if result is not None:
+                    if node.get('name'):
+                        result.append(node['name'])
+                    return result
+        return None
+
+    def rename_cluster(self, cluster_id, new_name):
+        print(f"Renaming cluster {cluster_id} to {new_name}")
+        def collect_names(node, names):
+            names.add(node.get('name'))
+            for child in node.get('children', []):
+                collect_names(child, names)
+        existing_names = set()
+        collect_names(self.Z_tree_format, existing_names)
+        unique_name = new_name
+        suffix = 1
+        while unique_name in existing_names:
+            unique_name = f"{new_name}_{suffix}"
+            suffix += 1
+
+        def rename_node(node):
+            if node.get('id') == cluster_id:
+                node['name'] = unique_name
+            for child in node.get('children', []):
+                rename_node(child)
+        rename_node(self.Z_tree_format)
+        return self.Z_tree_format
+
+    def get_node_name(self, node_id):
+        def find_name(node):
+            if node.get('id') == node_id:
+                return node.get('name')
+            for child in node.get('children', []):
+                result = find_name(child)
+                if result is not None:
+                    return result
+            return None
+        return find_name(self.Z_tree_format)
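`build_tree_hierarchy` leans on SciPy's `to_tree`; a self-contained sketch of that pairing on a made-up four-point linkage, mirroring the `_build_tree_format` recursion:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage, to_tree

# Toy linkage over four made-up 1-D points.
points = np.array([[0.0], [0.1], [1.0], [1.1]])
Z = linkage(points, method="average")
root, nodes = to_tree(Z, rd=True)

labels = ["cat", "dog", "car", "truck"]

def as_dict(node):
    # Leaves map to label names; internal nodes to cluster_<id> with a distance.
    if node.is_leaf():
        return {"id": node.id, "name": labels[node.id]}
    return {"id": node.id, "name": f"cluster_{node.id}",
            "children": [as_dict(node.get_left()), as_dict(node.get_right())],
            "value": node.dist}

print(as_dict(root))
```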
BETTER_NMA/utilss/classes/edges_dataframe.py
@@ -0,0 +1,53 @@
+import pandas as pd
+import os
+
+class EdgesDataframe:
+    def __init__(self, model_filename, edges_df_path):
+        """
+        Initialize EdgeDataframe handler.
+
+        Args:
+            model_filename: Name/path of the model
+            edges_df_path: Path to the edges dataframe CSV file
+        """
+        self.model_filename = model_filename
+        self.edges_df_path = edges_df_path
+        self.dataframe = None
+
+    def load_dataframe(self):
+        """Load the edges dataframe from CSV file."""
+        if not os.path.exists(self.edges_df_path):
+            raise FileNotFoundError(f"Edges dataframe not found at {self.edges_df_path}")
+
+        self.dataframe = pd.read_csv(self.edges_df_path)
+        return self.dataframe
+
+    def get_dataframe(self):
+        """Get the loaded dataframe."""
+        if self.dataframe is None:
+            raise ValueError("Dataframe not loaded. Call load_dataframe() first.")
+        return self.dataframe
+
+    def save_dataframe(self, path=None):
+        """Save the dataframe to CSV."""
+        if self.dataframe is None:
+            raise ValueError("No dataframe to save.")
+
+        save_path = path or self.edges_df_path
+        self.dataframe.to_csv(save_path, index=False)
+        return save_path
+
+    def filter_by_labels(self, source_labels=None, target_labels=None):
+        """Filter dataframe by source and/or target labels."""
+        if self.dataframe is None:
+            raise ValueError("Dataframe not loaded.")
+
+        filtered_df = self.dataframe.copy()
+
+        if source_labels:
+            filtered_df = filtered_df[filtered_df['source'].isin(source_labels)]
+
+        if target_labels:
+            filtered_df = filtered_df[filtered_df['target'].isin(target_labels)]
+
+        return filtered_df
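`filter_by_labels` is a thin wrapper over `DataFrame.isin`; a self-contained sketch on a made-up frame with the expected `source`/`target` schema:

```python
import pandas as pd

# Made-up edges frame matching the schema GraphBuilder emits.
edges = pd.DataFrame({
    "image_id": [1, 2, 3],
    "source": ["cat", "dog", "cat"],
    "target": ["dog", "wolf", "tiger"],
    "target_probability": [0.4, 0.3, 0.2],
})
filtered = edges[edges["source"].isin(["cat"])]  # rows whose source is "cat"
print(filtered)
```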
BETTER_NMA/utilss/classes/preprocessing/__init__.py
File without changes
BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py
@@ -0,0 +1,28 @@
+import numpy as np
+import tensorflow as tf
+
+
+class BatchPredictor:
+    def __init__(self, model, batch_size=32):
+        self.model = model
+        self.batch_size = batch_size
+        self.buffer_images = []  # To store images
+        self.buffer_labels = []  # To store corresponding labels
+        self.buffer_results = []  # To store batch results
+
+    def get_top_predictions(self, X, labels, top_k, graph_threshold):
+        batch_preds = self.model.predict(np.array(X))
+        batch_results = []
+        for pred in batch_preds:
+            top_indices = pred.argsort()[-top_k:][::-1]
+            valid_indices = [i for i in top_indices if i < len(labels)]
+
+            top_predictions = [
+                (i, labels[i], pred[i])
+                for i in valid_indices
+                if pred[i] >= graph_threshold
+            ]
+
+            batch_results.append(top_predictions)
+
+        return batch_results
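The core of `get_top_predictions` is the argsort idiom for the k largest probabilities plus a threshold filter; a self-contained illustration on one made-up prediction vector:

```python
import numpy as np

# One made-up probability vector; same top-k idiom as get_top_predictions.
pred = np.array([0.05, 0.6, 0.1, 0.25])
labels = ["cat", "dog", "car", "truck"]
top_k, graph_threshold = 2, 0.2

top_indices = pred.argsort()[-top_k:][::-1]  # indices of the k largest values
top = [(i, labels[i], pred[i]) for i in top_indices
       if pred[i] >= graph_threshold]
print(top)  # (1, 'dog', 0.6) and (3, 'truck', 0.25)
```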
BETTER_NMA/utilss/classes/preprocessing/graph_builder.py
@@ -0,0 +1,46 @@
+from ...enums.explanation_method import ExplanationMethod
+
+class GraphBuilder:
+    def __init__(self, graph_type, infinity):
+        self.graph_type = graph_type
+        self.infinity = infinity
+
+    def create_edge_weight(self, pred_prob):
+        if self.graph_type == ExplanationMethod.DISSIMILARITY.value:
+            return 1 - pred_prob
+        elif self.graph_type == ExplanationMethod.COUNT.value:
+            return 1
+        return pred_prob
+
+    def update_graph(self, graph, source_label, target_label, probability, image_id):
+        if source_label == target_label:
+            return None
+
+        weight = self.create_edge_weight(probability)
+
+        if graph.are_adjacent(source_label, target_label):
+            edge_id = graph.get_eid(source_label, target_label)
+            graph.es[edge_id]["weight"] += weight
+        else:
+            graph.add_edge(source_label, target_label, weight=weight)
+
+        edge_data = {
+            "image_id": image_id,
+            "source": source_label,
+            "target": target_label,
+            "target_probability": probability,
+        }
+
+        return edge_data
+
+
+    def add_infinity_edges(self, graph, infinity_edges_labels, label, source_label):
+        if label == source_label:
+            return
+
+        if label not in infinity_edges_labels:
+            if graph.are_adjacent(source_label, label):
+                edge_id = graph.get_eid(source_label, label)
+                graph.es[edge_id]["weight"] += self.infinity
+            else:
+                graph.add_edge(source_label, label, weight=self.infinity)
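`update_graph` uses python-igraph's accumulate-or-create idiom for weighted edges. A self-contained sketch, assuming a recent python-igraph (the `are_adjacent` call above implies one):

```python
import igraph as ig

g = ig.Graph()
g.add_vertices(["cat", "dog"])  # sets the vertex "name" attribute

def bump_edge(graph, source, target, weight):
    # Same accumulate-or-create idiom as GraphBuilder.update_graph.
    if graph.are_adjacent(source, target):
        graph.es[graph.get_eid(source, target)]["weight"] += weight
    else:
        graph.add_edge(source, target, weight=weight)

bump_edge(g, "cat", "dog", 0.4)
bump_edge(g, "cat", "dog", 0.3)
print(g.es[g.get_eid("cat", "dog")]["weight"])  # ~0.7
```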
BETTER_NMA/utilss/classes/preprocessing/heap_processor.py
@@ -0,0 +1,30 @@
+import heapq
+import copy
+
+from ...enums.heap_types import HeapType
+from ...enums.explanation_method import ExplanationMethod
+
+class HeapProcessor:
+    def __init__(self, graph, graph_type, labels):
+        self.heap_type = self._get_heap_type(graph_type)
+        self.heap = []
+        self.nodes_multiplier = -1 if self.heap_type == "max" else 1
+
+        self._process_edges(graph, labels)
+
+
+    def _get_heap_type(self, graph_type):
+        return HeapType.MINIMUM.value if graph_type == ExplanationMethod.DISSIMILARITY.value else HeapType.MAXIMUM.value
+
+    def _process_edges(self, graph, labels):
+        for edge in graph.es:
+            source = graph.vs[edge.source]["name"]
+            target = graph.vs[edge.target]["name"]
+            weight = edge["weight"] if "weight" in edge.attributes() else 0
+            heapq.heappush(self.heap, (self.nodes_multiplier * weight, source, target))
+
+    def get_heap(self):
+        return self.heap
+
+    def get_heap_copy(self):
+        return copy.deepcopy(self.heap)
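Python's `heapq` is min-heap only, which is why `nodes_multiplier` negates weights for the "max" case. A self-contained illustration:

```python
import heapq

edges = [("cat", "dog", 0.9), ("cat", "car", 0.1), ("dog", "wolf", 0.5)]

heap = []
for source, target, weight in edges:
    # Negating the weight turns the min-heap into a max-heap.
    heapq.heappush(heap, (-1 * weight, source, target))

weight, source, target = heapq.heappop(heap)
print(-weight, source, target)  # 0.9 cat dog (heaviest edge comes out first)
```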
BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py
@@ -0,0 +1,102 @@
+import numpy as np
+import json
+import os
+from .tree_node import TreeNode
+from ...enums.heap_types import HeapType
+
+class HierarchicalClusteringBuilder:
+    def __init__(self, heap_processor, labels):
+        self.nodes = {}
+        self.next_cluster_id = len(labels)
+        heap = heap_processor.get_heap()
+        self.max_weight = max((abs(weight) for weight, _, _ in heap), default=0)
+        self.forest = []
+        self._build(heap_processor, labels)
+
+    def _initialize_nodes(self, labels):
+        for i, label in enumerate(labels):
+            self.nodes[label] = TreeNode(node_id=i, label=label)
+
+    def _get_root(self, node):
+        while node.parent is not None:
+            node = node.parent
+        return node
+
+    def _build(self, heap_processor, labels):
+        self._initialize_nodes(labels)
+        n = len(labels)
+        distances = np.full((n, n), np.inf)
+        np.fill_diagonal(distances, 0)
+        label_to_idx = {label: i for i, label in enumerate(labels)}
+        for weight, source, target in heap_processor.get_heap():
+            if heap_processor.heap_type == HeapType.MAXIMUM.value:
+                distance = self.max_weight - (-1 * weight)
+            else:
+                distance = weight
+            src_idx = label_to_idx[source]
+            tgt_idx = label_to_idx[target]
+            distances[src_idx, tgt_idx] = distance
+            distances[tgt_idx, src_idx] = distance
+        max_dist = np.max(distances[~np.isinf(distances)])
+        distances[np.isinf(distances)] = max_dist * 2
+        active_clusters = {i: [i] for i in range(n)}
+        next_id = n
+        while len(active_clusters) > 1:
+            min_dist = float('inf')
+            closest_pair = None
+            for i in active_clusters:
+                for j in active_clusters:
+                    if i < j:
+                        dist_sum = 0
+                        count = 0
+                        for idx1 in active_clusters[i]:
+                            for idx2 in active_clusters[j]:
+                                dist_sum += distances[idx1, idx2]
+                                count += 1
+                        avg_dist = dist_sum / count
+                        if avg_dist < min_dist:
+                            min_dist = avg_dist
+                            closest_pair = (i, j)
+            i, j = closest_pair
+            if len(active_clusters[i]) == 1:
+                node_i = self.nodes[labels[active_clusters[i][0]]]
+            else:
+                node_i = self.nodes[f"cluster_{i}"]
+            if len(active_clusters[j]) == 1:
+                node_j = self.nodes[labels[active_clusters[j][0]]]
+            else:
+                node_j = self.nodes[f"cluster_{j}"]
+            cluster_name = f"cluster_{next_id}"
+            new_node = TreeNode(
+                node_id=next_id,
+                label=cluster_name,
+                children=[node_i, node_j],
+                weight=min_dist
+            )
+            self.nodes[cluster_name] = new_node
+            node_i.parent = new_node
+            node_j.parent = new_node
+            merged_cluster = active_clusters[i] + active_clusters[j]
+            active_clusters[next_id] = merged_cluster
+            del active_clusters[i]
+            del active_clusters[j]
+            next_id += 1
+        final_id = list(active_clusters.keys())[0]
+        if final_id < n:
+            self.final_tree = self.nodes[labels[final_id]]
+        else:
+            self.final_tree = self.nodes[f"cluster_{final_id}"]
+        self.forest = [self.final_tree]
+        print(f"Built a forest with {len(self.forest)} tree{'s' if len(self.forest) > 1 else ''}")
+
+
+    def get_forest(self):
+        return self.forest
+
+    def get_tree(self, index=0):
+        if 0 <= index < len(self.forest):
+            return self.forest[index]
+        return None
+
+    def forest_size(self):
+        return len(self.forest)
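The nested loops in `_build` implement average linkage (mean pairwise distance between clusters) by hand; SciPy's `linkage` with `method="average"` computes the same strategy, as this self-contained sketch on a made-up distance matrix shows:

```python
import numpy as np
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform

# Made-up symmetric distance matrix over four labels, zero diagonal.
D = np.array([[0.0, 0.2, 0.9, 0.8],
              [0.2, 0.0, 0.7, 0.9],
              [0.9, 0.7, 0.0, 0.1],
              [0.8, 0.9, 0.1, 0.0]])

# Average linkage over the condensed form: the strategy the builder's
# nested loops reimplement manually.
Z = linkage(squareform(D), method="average")
print(Z)  # each row: merged cluster ids, merge distance, new cluster size
```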
BETTER_NMA/utilss/classes/preprocessing/tree_node.py
@@ -0,0 +1,71 @@
+from nltk.tree import Tree
+
+class TreeNode(Tree):
+    def __init__(self, node_id, label=None, children=None, weight=0, parent=None):
+        self.node_name = label if label else f"cluster {node_id}"
+        super().__init__(self.node_name, children if children else [])
+        self.node_id = node_id
+        self.weight = weight
+        self.parent = parent
+        if children:
+            for child in children:
+                if isinstance(child, TreeNode):
+                    child.parent = self
+
+    def __hash__(self):
+        return hash(self.node_id)
+
+    def __eq__(self, other):
+        if not isinstance(other, TreeNode):
+            return False
+        return self.node_id == other.node_id
+
+    def add_child(self, child):
+        if isinstance(child, TreeNode):
+            child.parent = self
+        self.append(child)
+        return self
+
+    @classmethod
+    def create_parent(cls, node_id, children, label=None, weight=0):
+        parent = cls(node_id, label, weight=weight)
+        for child in children:
+            parent.add_child(child)
+        return parent
+
+    def get_root(self):
+        current = self
+        while current.parent is not None:
+            current = current.parent
+        return current
+
+    def get_path_to_root(self):
+        path = []
+        current = self
+        while current is not None:
+            path.append(current)
+            current = current.parent
+        return path
+
+    def find_lca(self, other_node):
+        if not isinstance(other_node, TreeNode):
+            raise TypeError("Expected TreeNode node")
+        path1 = []
+        current = self
+        while current is not None:
+            path1.append(current)
+            current = current.parent
+        path2 = []
+        current = other_node
+        while current is not None:
+            path2.append(current)
+            current = current.parent
+        path1.reverse()
+        path2.reverse()
+        lca = None
+        for i in range(min(len(path1), len(path2))):
+            if path1[i] is path2[i]:
+                lca = path1[i]
+            else:
+                break
+        return lca
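A hedged sketch exercising `find_lca` directly (assuming `TreeNode` is importable, e.g. `from BETTER_NMA.utilss.classes.preprocessing.tree_node import TreeNode`, and that nltk is installed since `TreeNode` subclasses `nltk.tree.Tree`):

```python
# Tiny made-up hierarchy: cluster_4 = (cat, dog); cluster_5 = (cluster_4, car).
cat, dog, car = TreeNode(0, "cat"), TreeNode(1, "dog"), TreeNode(2, "car")
inner = TreeNode.create_parent(4, [cat, dog], label="cluster_4")
root = TreeNode.create_parent(5, [inner, car], label="cluster_5")

print(cat.find_lca(dog).node_name)  # cluster_4 (deepest shared ancestor)
print(cat.find_lca(car).node_name)  # cluster_5 (only the root is shared)
```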