BETTER-NMA 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. BETTER_NMA/__init__.py +15 -0
  2. BETTER_NMA/adversarial_score.py +19 -0
  3. BETTER_NMA/change_cluster_name.py +0 -0
  4. BETTER_NMA/detect_attack.py +108 -0
  5. BETTER_NMA/find_lca.py +21 -0
  6. BETTER_NMA/main.py +285 -0
  7. BETTER_NMA/nma_creator.py +108 -0
  8. BETTER_NMA/plot.py +131 -0
  9. BETTER_NMA/query_image.py +22 -0
  10. BETTER_NMA/train_adversarial_detector.py +21 -0
  11. BETTER_NMA/utilss/__init__.py +0 -0
  12. BETTER_NMA/utilss/classes/__init__.py +0 -0
  13. BETTER_NMA/utilss/classes/adversarial_dataset.py +61 -0
  14. BETTER_NMA/utilss/classes/adversarial_detector.py +63 -0
  15. BETTER_NMA/utilss/classes/dendrogram.py +131 -0
  16. BETTER_NMA/utilss/classes/edges_dataframe.py +53 -0
  17. BETTER_NMA/utilss/classes/preprocessing/__init__.py +0 -0
  18. BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py +28 -0
  19. BETTER_NMA/utilss/classes/preprocessing/graph_builder.py +46 -0
  20. BETTER_NMA/utilss/classes/preprocessing/heap_processor.py +30 -0
  21. BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py +102 -0
  22. BETTER_NMA/utilss/classes/preprocessing/tree_node.py +71 -0
  23. BETTER_NMA/utilss/classes/preprocessing/z_builder.py +93 -0
  24. BETTER_NMA/utilss/classes/score_calculator.py +165 -0
  25. BETTER_NMA/utilss/classes/whitebox_testing.py +35 -0
  26. BETTER_NMA/utilss/enums/__init__.py +0 -0
  27. BETTER_NMA/utilss/enums/explanation_method.py +6 -0
  28. BETTER_NMA/utilss/enums/heap_types.py +5 -0
  29. BETTER_NMA/utilss/models_utils.py +18 -0
  30. BETTER_NMA/utilss/photos_uitls.py +72 -0
  31. BETTER_NMA/utilss/photos_utils.py +104 -0
  32. BETTER_NMA/utilss/verbal_explanation.py +15 -0
  33. BETTER_NMA/utilss/wordnet_utils.py +177 -0
  34. BETTER_NMA/white_box_testing.py +101 -0
  35. BETTER_NMA-1.0.0.dist-info/METADATA +11 -0
  36. BETTER_NMA-1.0.0.dist-info/RECORD +38 -0
  37. BETTER_NMA-1.0.0.dist-info/WHEEL +5 -0
  38. BETTER_NMA-1.0.0.dist-info/top_level.txt +1 -0
BETTER_NMA/query_image.py
@@ -0,0 +1,22 @@
+ from .utilss.verbal_explanation import get_verbal_explanation
+ from .utilss.photos_utils import preprocess_loaded_image
+ from .utilss.models_utils import get_top_k_predictions
+
+ def query_image(image, model, labels, dendrogram_object, top_k=5):
+     # predict image
+     try:
+         preprocessed_image, pil_image = preprocess_loaded_image(model, image)
+         predictions = get_top_k_predictions(
+             model, preprocessed_image, labels)
+         top_label = predictions[0][0]  # Top label
+         top_k_predictions = predictions[:top_k]
+
+         consistency = dendrogram_object.find_name_hierarchy(
+             dendrogram_object.Z_tree_format, top_label)
+
+         explanation = get_verbal_explanation(consistency)
+
+         return top_k_predictions, explanation
+     except Exception as e:
+         print("Error occurred while querying image:", e)
+         return None
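A hedged call sketch for this entry point. Every input below is a stand-in, not something defined in this file: a trained Keras-style classifier, the class-name list, and a built Dendrogram object (see dendrogram.py further down) would normally come from the package's setup step in main.py. Note that query_image returns None on failure, so the result is unpacked guardedly.

    from BETTER_NMA.query_image import query_image

    result = query_image(
        image="path/to/image.jpg",     # assumed image input
        model=trained_model,           # stand-in: trained classifier
        labels=class_names,            # stand-in: list of class names
        dendrogram_object=dendrogram,  # stand-in: built Dendrogram
        top_k=5,
    )
    if result is not None:
        top_k_predictions, explanation = result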
BETTER_NMA/train_adversarial_detector.py
@@ -0,0 +1,21 @@
+ from .utilss.classes.adversarial_detector import AdversarialDetector
+ from .utilss.classes.adversarial_dataset import AdversarialDataset
+
+ def _create_adversarial_dataset(Z_matrix, clean_images, adversarial_images, model, labels) -> dict:
+     adversarial_dataset = AdversarialDataset(model, clean_images, adversarial_images, Z_matrix, labels)
+     X_train, y_train, X_test, y_test = adversarial_dataset.create_logistic_regression_dataset()
+     result = {
+         "X_train": X_train,
+         "y_train": y_train,
+         "X_test": X_test,
+         "y_test": y_test
+     }
+     return result
+
+ def create_logistic_regression_detector(Z_matrix, model, clean_images, adversarial_images, labels):
+     adversarial_dataset = _create_adversarial_dataset(Z_matrix, clean_images, adversarial_images, model, labels)
+     adversarial_detector = AdversarialDetector(adversarial_dataset)
+
+     print("Adversarial detector trained successfully!")
+
+     return adversarial_detector
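A hedged end-to-end sketch of this trainer. All arguments are stand-ins (a linkage matrix from the package's NMA build step, a trained classifier, and matched clean/adversarial image arrays); none are defined in this file.

    from BETTER_NMA.train_adversarial_detector import create_logistic_regression_detector

    detector = create_logistic_regression_detector(
        Z_matrix=Z,                    # stand-in: linkage matrix from the NMA build
        model=trained_model,           # stand-in: trained classifier
        clean_images=clean_images,     # stand-in: benign inputs
        adversarial_images=adv_images, # stand-in: attacked inputs
        labels=class_names,            # stand-in: class names
    )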
BETTER_NMA/utilss/__init__.py: File without changes
BETTER_NMA/utilss/classes/__init__.py: File without changes
BETTER_NMA/utilss/classes/adversarial_dataset.py
@@ -0,0 +1,61 @@
+ from sklearn.model_selection import train_test_split
+ import numpy as np
+ from .score_calculator import ScoreCalculator
+
+ class AdversarialDataset:
+     def __init__(self, model, clear_images, adversarial_images, Z_full, labels):
+         self.model = model
+         self.clear_images = clear_images
+         self.adversarial_images = adversarial_images
+         self.score_calculator = ScoreCalculator(Z_full=Z_full, class_names=labels)
+
+     def create_logistic_regression_dataset(self):
+         scores = []
+         labels = []
+
+         print("getting preprocess function...")
+
+         try:
+             for image in self.clear_images[:50]:
+                 # Add batch dimension for model prediction
+                 image_batch = np.expand_dims(image, axis=0)
+                 score = self.score_calculator.calculate_adversarial_score(self.model.predict(image_batch))
+                 scores.append(score)
+                 labels.append(0)
+         except Exception as e:
+             print(f"Error processing clean image: {e}")
+
+         # Generate features for PGD attacks
+         print("Generating attack features...")
+         try:
+             for adv_image in self.adversarial_images[:50]:
+                 # Add batch dimension for model prediction
+                 adv_image_batch = np.expand_dims(adv_image, axis=0)
+                 score = self.score_calculator.calculate_adversarial_score(self.model.predict(adv_image_batch))
+                 scores.append(score)
+                 labels.append(1)
+         except Exception as e:
+             print(f"Error processing PGD attack on image: {e}")
+
+         print("labels:", labels)
+         print("scores:", scores)
+
+         # Convert to numpy arrays
+         X = np.array(scores)
+         y = np.array(labels)
+
+         # Reshape X to ensure it is 2D
+         if len(X.shape) == 1:
+             X = X.reshape(-1, 1)
+
+         # Split into training and test sets
+         X_train, X_test, y_train, y_test = train_test_split(
+             X, y, test_size=0.3, random_state=42
+         )
+
+         print(f"Training data shape: {X_train.shape}")
+         print(f"Clean samples: {sum(y_train == 0)}, Adversarial samples: {sum(y_train == 1)}")
+         print(f"Test data shape: {X_test.shape}")
+         print(f"Clean samples: {sum(y_test == 0)}, Adversarial samples: {sum(y_test == 1)}")
+
+         return X_train, y_train, X_test, y_test
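The only subtle step above is reshaping the 1-D list of per-image scores into a 2-D feature matrix before the split. A self-contained illustration with made-up scores:

    import numpy as np
    from sklearn.model_selection import train_test_split

    scores = np.array([0.12, 0.91, 0.08, 0.77])  # made-up per-image scores
    labels = np.array([0, 1, 0, 1])              # 0 = clean, 1 = adversarial
    X = scores.reshape(-1, 1)                    # (4,) -> (4, 1), as sklearn expects
    X_train, X_test, y_train, y_test = train_test_split(
        X, labels, test_size=0.5, random_state=42
    )
    print(X_train.shape)  # (2, 1)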
BETTER_NMA/utilss/classes/adversarial_detector.py
@@ -0,0 +1,63 @@
+ from sklearn.linear_model import LogisticRegression
+ from sklearn.metrics import roc_curve
+ import numpy as np
+
+ class AdversarialDetector:
+     def __init__(self, dataset):
+         self.detector, self.threshold = self.train_adversarial_detector(dataset)
+
+     def predict(self, X):
+         # Predict probabilities
+         if self.detector is None:
+             raise ValueError("Detector model is not trained or loaded.")
+         if X is None or len(X) == 0:
+             raise ValueError("Input data is empty or None.")
+
+         y_pred_proba = self.detector.predict_proba(X)[:, 1]
+         # Apply the custom threshold
+         return (y_pred_proba >= self.threshold).astype(int)
+
+     def predict_proba(self, X):
+         if self.detector is None:
+             raise ValueError("Detector model is not trained or loaded.")
+         if X is None or len(X) == 0:
+             raise ValueError("Input data is empty or None.")
+         # Return predicted probabilities
+         return self.detector.predict_proba(X)
+
+     def train_adversarial_detector(self, dataset):
+         """
+         Train a logistic regression model to detect adversarial examples across different attack types.
+
+         Parameters:
+         - dataset: dict with "X_train", "y_train", "X_test", "y_test" arrays,
+           as produced by AdversarialDataset.create_logistic_regression_dataset()
+
+         Returns:
+         - The trained detector model and the optimal decision threshold
+         """
+         X_train, y_train = dataset['X_train'], dataset['y_train']
+
+         # Train logistic regression model
+         print("Training adversarial detector...")
+         detector = LogisticRegression(max_iter=1000, class_weight='balanced')
+         detector.fit(X_train, y_train)
+
+         X_test, y_test = dataset['X_test'], dataset['y_test']
+
+         # Predict probabilities for the test set
+         y_pred_proba = detector.predict_proba(X_test)[:, 1]
+
+         # Compute the ROC curve
+         fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
+
+         # Find the optimal threshold (closest to top-left corner)
+         optimal_idx = np.argmax(tpr - fpr)
+         optimal_threshold = thresholds[optimal_idx]
+
+         return detector, optimal_threshold
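The threshold selection above maximizes tpr - fpr (Youden's J statistic) over the ROC curve. A toy, self-contained reproduction of that step with made-up detector scores:

    import numpy as np
    from sklearn.metrics import roc_curve

    y_true = np.array([0, 0, 1, 1])            # made-up test labels
    y_score = np.array([0.1, 0.4, 0.35, 0.8])  # made-up detector probabilities
    fpr, tpr, thresholds = roc_curve(y_true, y_score)
    optimal_threshold = thresholds[np.argmax(tpr - fpr)]
    print(optimal_threshold)  # 0.8 for this toy data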
BETTER_NMA/utilss/classes/dendrogram.py
@@ -0,0 +1,131 @@
+ from scipy.cluster.hierarchy import to_tree
+ from ..wordnet_utils import process_hierarchy
+ import json
+ import os
+ import pickle
+
+ class Dendrogram:
+     def __init__(self, Z=None):
+         self.Z = Z
+         self.Z_tree_format = None
+         self.dendrogram_filename = None
+
+     def _build_tree_format(self, node, labels):
+         if node.is_leaf():
+             return {
+                 "id": node.id,
+                 "name": labels[node.id],
+             }
+         else:
+             return {
+                 "id": node.id,
+                 "name": f"cluster_{node.id}",
+                 "children": [self._build_tree_format(node.get_left(), labels), self._build_tree_format(node.get_right(), labels)],
+                 "value": node.dist
+             }
+
+     def build_tree_hierarchy(self, linkage_matrix, labels):
+         tree, nodes = to_tree(linkage_matrix, rd=True)
+         self.Z_tree_format = self._build_tree_format(tree, labels)
+         self.Z_tree_format = process_hierarchy(self.Z_tree_format)
+         return self.Z_tree_format
+
+     def filter_dendrogram_by_labels(self, full_data, target_labels):
+         def contains_target_label(node):
+             if 'children' not in node:
+                 return node.get('name') in target_labels
+             for child in node.get('children', []):
+                 if contains_target_label(child):
+                     return True
+             return False
+
+         def filter_tree(node):
+             if not contains_target_label(node):
+                 return None
+             new_node = {
+                 'id': node.get('id'),
+                 'name': node.get('name')
+             }
+             if 'value' in node:
+                 new_node['value'] = node.get('value')
+             if 'children' not in node:
+                 return new_node
+             filtered_children = []
+             for child in node.get('children', []):
+                 filtered_child = filter_tree(child)
+                 if filtered_child:
+                     filtered_children.append(filtered_child)
+             if filtered_children:
+                 new_node['children'] = filtered_children
+             return new_node
+
+         return filter_tree(full_data)
+
+     def merge_clusters(self, node):
+         if node is None:
+             return None
+         if "children" not in node:
+             return node
+         merged_children = []
+         for child in node["children"]:
+             merged_child = self.merge_clusters(child)
+             if merged_child:
+                 merged_children.append(merged_child)
+         if all(c.get("value", 0) == 100 for c in merged_children):
+             node["children"] = [grandchild for child in merged_children for grandchild in child.get("children", [])]
+         else:
+             node["children"] = merged_children
+         if len(node["children"]) == 1:
+             return node["children"][0]
+         return node
+
+     def get_sub_dendrogram_formatted(self, selected_labels):
+         filtered_tree = self.filter_dendrogram_by_labels(self.Z_tree_format, selected_labels)
+         if filtered_tree is None:
+             raise ValueError(f"No clusters found for the selected labels: {selected_labels}")
+         filtered_tree = self.merge_clusters(filtered_tree)
+         filtered_tree_json = json.dumps(filtered_tree, indent=2)
+         return filtered_tree_json
+
+     def find_name_hierarchy(self, node, target_name):
+         if node.get('name') == target_name:
+             return [target_name]
+         if 'children' in node:
+             for child in node['children']:
+                 result = self.find_name_hierarchy(child, target_name)
+                 if result is not None:
+                     if node.get('name'):
+                         result.append(node['name'])
+                     return result
+         return None
+
+     def rename_cluster(self, cluster_id, new_name):
+         print(f"Renaming cluster {cluster_id} to {new_name}")
+
+         def collect_names(node, names):
+             names.add(node.get('name'))
+             for child in node.get('children', []):
+                 collect_names(child, names)
+
+         existing_names = set()
+         collect_names(self.Z_tree_format, existing_names)
+         unique_name = new_name
+         suffix = 1
+         while unique_name in existing_names:
+             unique_name = f"{new_name}_{suffix}"
+             suffix += 1
+
+         def rename_node(node):
+             if node.get('id') == cluster_id:
+                 node['name'] = unique_name
+             for child in node.get('children', []):
+                 rename_node(child)
+
+         rename_node(self.Z_tree_format)
+         return self.Z_tree_format
+
+     def get_node_name(self, node_id):
+         def find_name(node):
+             if node.get('id') == node_id:
+                 return node.get('name')
+             for child in node.get('children', []):
+                 result = find_name(child)
+                 if result is not None:
+                     return result
+             return None
+
+         return find_name(self.Z_tree_format)
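A hedged sketch of building the tree-format dendrogram from a toy SciPy linkage. The labels are made up, and build_tree_hierarchy calls process_hierarchy from wordnet_utils internally, which may require NLTK WordNet data and may rename clusters, so the printed names are indicative only.

    import numpy as np
    from scipy.cluster.hierarchy import linkage
    from BETTER_NMA.utilss.classes.dendrogram import Dendrogram

    rng = np.random.default_rng(0)
    Z = linkage(rng.random((4, 2)), method="average")  # toy 4-leaf linkage
    labels = ["cat", "dog", "car", "truck"]            # made-up class names

    dendrogram = Dendrogram(Z)
    tree = dendrogram.build_tree_hierarchy(Z, labels)
    print(dendrogram.find_name_hierarchy(tree, "cat"))  # e.g. ['cat', 'cluster_4', ...]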
BETTER_NMA/utilss/classes/edges_dataframe.py
@@ -0,0 +1,53 @@
+ import pandas as pd
+ import os
+
+ class EdgesDataframe:
+     def __init__(self, model_filename, edges_df_path):
+         """
+         Initialize the EdgesDataframe handler.
+
+         Args:
+             model_filename: Name/path of the model
+             edges_df_path: Path to the edges dataframe CSV file
+         """
+         self.model_filename = model_filename
+         self.edges_df_path = edges_df_path
+         self.dataframe = None
+
+     def load_dataframe(self):
+         """Load the edges dataframe from a CSV file."""
+         if not os.path.exists(self.edges_df_path):
+             raise FileNotFoundError(f"Edges dataframe not found at {self.edges_df_path}")
+
+         self.dataframe = pd.read_csv(self.edges_df_path)
+         return self.dataframe
+
+     def get_dataframe(self):
+         """Get the loaded dataframe."""
+         if self.dataframe is None:
+             raise ValueError("Dataframe not loaded. Call load_dataframe() first.")
+         return self.dataframe
+
+     def save_dataframe(self, path=None):
+         """Save the dataframe to CSV."""
+         if self.dataframe is None:
+             raise ValueError("No dataframe to save.")
+
+         save_path = path or self.edges_df_path
+         self.dataframe.to_csv(save_path, index=False)
+         return save_path
+
+     def filter_by_labels(self, source_labels=None, target_labels=None):
+         """Filter the dataframe by source and/or target labels."""
+         if self.dataframe is None:
+             raise ValueError("Dataframe not loaded.")
+
+         filtered_df = self.dataframe.copy()
+
+         if source_labels:
+             filtered_df = filtered_df[filtered_df['source'].isin(source_labels)]
+
+         if target_labels:
+             filtered_df = filtered_df[filtered_df['target'].isin(target_labels)]
+
+         return filtered_df
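A self-contained usage sketch: write a tiny edges CSV, then load and filter it. The column names mirror the edge_data dict built in graph_builder.py below; only 'source' and 'target' are actually required by filter_by_labels, and the file names here are made up.

    import pandas as pd
    from BETTER_NMA.utilss.classes.edges_dataframe import EdgesDataframe

    pd.DataFrame({
        "image_id": [0, 1],
        "source": ["cat", "dog"],
        "target": ["tiger", "wolf"],
        "target_probability": [0.7, 0.6],
    }).to_csv("edges.csv", index=False)

    edf = EdgesDataframe("model.h5", "edges.csv")  # model name is unused by these methods
    edf.load_dataframe()
    print(edf.filter_by_labels(source_labels=["cat"]))  # keeps only the cat -> tiger row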
BETTER_NMA/utilss/classes/preprocessing/__init__.py: File without changes
BETTER_NMA/utilss/classes/preprocessing/batch_predictor.py
@@ -0,0 +1,28 @@
+ import numpy as np
+ import tensorflow as tf
+
+
+ class BatchPredictor:
+     def __init__(self, model, batch_size=32):
+         self.model = model
+         self.batch_size = batch_size
+         self.buffer_images = []  # To store images
+         self.buffer_labels = []  # To store corresponding labels
+         self.buffer_results = []  # To store batch results
+
+     def get_top_predictions(self, X, labels, top_k, graph_threshold):
+         batch_preds = self.model.predict(np.array(X))
+         batch_results = []
+         for pred in batch_preds:
+             top_indices = pred.argsort()[-top_k:][::-1]
+             valid_indices = [i for i in top_indices if i < len(labels)]
+
+             top_predictions = [
+                 (i, labels[i], pred[i])
+                 for i in valid_indices
+                 if pred[i] >= graph_threshold
+             ]
+
+             batch_results.append(top_predictions)
+
+         return batch_results
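A hedged sketch with a throwaway softmax model; the labels and threshold are made up, and the output is a list (one entry per input) of (index, label, probability) tuples, as built above.

    import numpy as np
    import tensorflow as tf
    from BETTER_NMA.utilss.classes.preprocessing.batch_predictor import BatchPredictor

    inputs = tf.keras.Input(shape=(4,))
    outputs = tf.keras.layers.Dense(3, activation="softmax")(inputs)
    model = tf.keras.Model(inputs, outputs)  # untrained toy model

    predictor = BatchPredictor(model)
    X = np.random.rand(2, 4).astype("float32")
    results = predictor.get_top_predictions(X, labels=["a", "b", "c"], top_k=2, graph_threshold=0.0)
    print(results)  # e.g. [[(1, 'b', 0.41), (0, 'a', 0.33)], ...]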
BETTER_NMA/utilss/classes/preprocessing/graph_builder.py
@@ -0,0 +1,46 @@
+ from ...enums.explanation_method import ExplanationMethod
+
+ class GraphBuilder:
+     def __init__(self, graph_type, infinity):
+         self.graph_type = graph_type
+         self.infinity = infinity
+
+     def create_edge_weight(self, pred_prob):
+         if self.graph_type == ExplanationMethod.DISSIMILARITY.value:
+             return 1 - pred_prob
+         elif self.graph_type == ExplanationMethod.COUNT.value:
+             return 1
+         return pred_prob
+
+     def update_graph(self, graph, source_label, target_label, probability, image_id):
+         if source_label == target_label:
+             return None
+
+         weight = self.create_edge_weight(probability)
+
+         if graph.are_adjacent(source_label, target_label):
+             edge_id = graph.get_eid(source_label, target_label)
+             graph.es[edge_id]["weight"] += weight
+         else:
+             graph.add_edge(source_label, target_label, weight=weight)
+
+         edge_data = {
+             "image_id": image_id,
+             "source": source_label,
+             "target": target_label,
+             "target_probability": probability,
+         }
+
+         return edge_data
+
+
+     def add_infinity_edges(self, graph, infinity_edges_labels, label, source_label):
+         if label == source_label:
+             return
+
+         if label not in infinity_edges_labels:
+             if graph.are_adjacent(source_label, label):
+                 edge_id = graph.get_eid(source_label, label)
+                 graph.es[edge_id]["weight"] += self.infinity
+             else:
+                 graph.add_edge(source_label, label, weight=self.infinity)
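A hedged sketch of accumulating one edge. It assumes a recent python-igraph (whose Graph.are_adjacent the code above relies on), and the "similarity" string is a guess that exercises the default weight branch; the real enum values live in explanation_method.py, which is not shown in this excerpt.

    import igraph as ig
    from BETTER_NMA.utilss.classes.preprocessing.graph_builder import GraphBuilder

    g = ig.Graph()
    g.add_vertices(["cat", "dog"])  # vertex names double as labels
    builder = GraphBuilder(graph_type="similarity", infinity=1000)

    edge_data = builder.update_graph(g, "cat", "dog", probability=0.8, image_id=0)
    print(edge_data)       # {'image_id': 0, 'source': 'cat', 'target': 'dog', ...}
    print(g.es["weight"])  # [0.8]; a repeat call would accumulate to [1.6]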
BETTER_NMA/utilss/classes/preprocessing/heap_processor.py
@@ -0,0 +1,30 @@
+ import heapq
+ import copy
+
+ from ...enums.heap_types import HeapType
+ from ...enums.explanation_method import ExplanationMethod
+
+ class HeapProcessor:
+     def __init__(self, graph, graph_type, labels):
+         self.heap_type = self._get_heap_type(graph_type)
+         self.heap = []
+         self.nodes_multiplier = -1 if self.heap_type == "max" else 1
+
+         self._process_edges(graph, labels)
+
+
+     def _get_heap_type(self, graph_type):
+         return HeapType.MINIMUM.value if graph_type == ExplanationMethod.DISSIMILARITY.value else HeapType.MAXIMUM.value
+
+     def _process_edges(self, graph, labels):
+         for edge in graph.es:
+             source = graph.vs[edge.source]["name"]
+             target = graph.vs[edge.target]["name"]
+             weight = edge["weight"] if "weight" in edge.attributes() else 0
+             heapq.heappush(self.heap, (self.nodes_multiplier * weight, source, target))
+
+     def get_heap(self):
+         return self.heap
+
+     def get_heap_copy(self):
+         return copy.deepcopy(self.heap)
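A self-contained illustration of the max-heap trick used above: Python's heapq is a min-heap, so weights are negated (nodes_multiplier = -1) to pop the heaviest edge first.

    import heapq

    heap = []
    for weight, source, target in [(0.9, "cat", "dog"), (0.4, "cat", "car")]:
        heapq.heappush(heap, (-1 * weight, source, target))  # negate for max-first order

    print(heapq.heappop(heap))  # (-0.9, 'cat', 'dog') -- the largest weight pops first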
BETTER_NMA/utilss/classes/preprocessing/hierarchical_clustering_builder.py
@@ -0,0 +1,102 @@
+ import numpy as np
+ import json
+ import os
+ from .tree_node import TreeNode
+ from ...enums.heap_types import HeapType
+
+ class HierarchicalClusteringBuilder:
+     def __init__(self, heap_processor, labels):
+         self.nodes = {}
+         self.next_cluster_id = len(labels)
+         heap = heap_processor.get_heap()
+         self.max_weight = max((abs(weight) for weight, _, _ in heap), default=0)
+         self.forest = []
+         self._build(heap_processor, labels)
+
+     def _initialize_nodes(self, labels):
+         for i, label in enumerate(labels):
+             self.nodes[label] = TreeNode(node_id=i, label=label)
+
+     def _get_root(self, node):
+         while node.parent is not None:
+             node = node.parent
+         return node
+
+     def _build(self, heap_processor, labels):
+         self._initialize_nodes(labels)
+         n = len(labels)
+         distances = np.full((n, n), np.inf)
+         np.fill_diagonal(distances, 0)
+         label_to_idx = {label: i for i, label in enumerate(labels)}
+         for weight, source, target in heap_processor.get_heap():
+             if heap_processor.heap_type == HeapType.MAXIMUM.value:
+                 distance = self.max_weight - (-1 * weight)
+             else:
+                 distance = weight
+             src_idx = label_to_idx[source]
+             tgt_idx = label_to_idx[target]
+             distances[src_idx, tgt_idx] = distance
+             distances[tgt_idx, src_idx] = distance
+         max_dist = np.max(distances[~np.isinf(distances)])
+         distances[np.isinf(distances)] = max_dist * 2
+         active_clusters = {i: [i] for i in range(n)}
+         next_id = n
+         while len(active_clusters) > 1:
+             min_dist = float('inf')
+             closest_pair = None
+             for i in active_clusters:
+                 for j in active_clusters:
+                     if i < j:
+                         dist_sum = 0
+                         count = 0
+                         for idx1 in active_clusters[i]:
+                             for idx2 in active_clusters[j]:
+                                 dist_sum += distances[idx1, idx2]
+                                 count += 1
+                         avg_dist = dist_sum / count
+                         if avg_dist < min_dist:
+                             min_dist = avg_dist
+                             closest_pair = (i, j)
+             i, j = closest_pair
+             if len(active_clusters[i]) == 1:
+                 node_i = self.nodes[labels[active_clusters[i][0]]]
+             else:
+                 node_i = self.nodes[f"cluster_{i}"]
+             if len(active_clusters[j]) == 1:
+                 node_j = self.nodes[labels[active_clusters[j][0]]]
+             else:
+                 node_j = self.nodes[f"cluster_{j}"]
+             cluster_name = f"cluster_{next_id}"
+             new_node = TreeNode(
+                 node_id=next_id,
+                 label=cluster_name,
+                 children=[node_i, node_j],
+                 weight=min_dist
+             )
+             self.nodes[cluster_name] = new_node
+             node_i.parent = new_node
+             node_j.parent = new_node
+             merged_cluster = active_clusters[i] + active_clusters[j]
+             active_clusters[next_id] = merged_cluster
+             del active_clusters[i]
+             del active_clusters[j]
+             next_id += 1
+         final_id = list(active_clusters.keys())[0]
+         if final_id < n:
+             self.final_tree = self.nodes[labels[final_id]]
+         else:
+             self.final_tree = self.nodes[f"cluster_{final_id}"]
+         self.forest = [self.final_tree]
+         print(f"Built a forest with {len(self.forest)} tree{'s' if len(self.forest) > 1 else ''}")
+
+
+     def get_forest(self):
+         return self.forest
+
+     def get_tree(self, index=0):
+         if 0 <= index < len(self.forest):
+             return self.forest[index]
+         return None
+
+     def forest_size(self):
+         return len(self.forest)
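The merge criterion in _build is average linkage: the distance between two clusters is the mean pairwise distance over their members. A self-contained toy version of that inner computation:

    import numpy as np

    distances = np.array([
        [0.0, 1.0, 4.0],
        [1.0, 0.0, 5.0],
        [4.0, 5.0, 0.0],
    ])
    cluster_i, cluster_j = [0, 1], [2]  # members of two candidate clusters
    avg_dist = np.mean([distances[a, b] for a in cluster_i for b in cluster_j])
    print(avg_dist)  # 4.5 -- the mean of d(0,2)=4.0 and d(1,2)=5.0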
BETTER_NMA/utilss/classes/preprocessing/tree_node.py
@@ -0,0 +1,71 @@
+ from nltk.tree import Tree
+
+ class TreeNode(Tree):
+     def __init__(self, node_id, label=None, children=None, weight=0, parent=None):
+         self.node_name = label if label else f"cluster {node_id}"
+         super().__init__(self.node_name, children if children else [])
+         self.node_id = node_id
+         self.weight = weight
+         self.parent = parent
+         if children:
+             for child in children:
+                 if isinstance(child, TreeNode):
+                     child.parent = self
+
+     def __hash__(self):
+         return hash(self.node_id)
+
+     def __eq__(self, other):
+         if not isinstance(other, TreeNode):
+             return False
+         return self.node_id == other.node_id
+
+     def add_child(self, child):
+         if isinstance(child, TreeNode):
+             child.parent = self
+         self.append(child)
+         return self
+
+     @classmethod
+     def create_parent(cls, node_id, children, label=None, weight=0):
+         parent = cls(node_id, label, weight=weight)
+         for child in children:
+             parent.add_child(child)
+         return parent
+
+     def get_root(self):
+         current = self
+         while current.parent is not None:
+             current = current.parent
+         return current
+
+     def get_path_to_root(self):
+         path = []
+         current = self
+         while current is not None:
+             path.append(current)
+             current = current.parent
+         return path
+
+     def find_lca(self, other_node):
+         if not isinstance(other_node, TreeNode):
+             raise TypeError("Expected a TreeNode instance")
+         path1 = []
+         current = self
+         while current is not None:
+             path1.append(current)
+             current = current.parent
+         path2 = []
+         current = other_node
+         while current is not None:
+             path2.append(current)
+             current = current.parent
+         path1.reverse()
+         path2.reverse()
+         lca = None
+         for i in range(min(len(path1), len(path2))):
+             if path1[i] is path2[i]:
+                 lca = path1[i]
+             else:
+                 break
+         return lca
+
+ return lca