compositional-explanations 0.0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22) hide show
  1. compositional_explanations-0.0.0.1/LICENSE +19 -0
  2. compositional_explanations-0.0.0.1/PKG-INFO +86 -0
  3. compositional_explanations-0.0.0.1/README.md +68 -0
  4. compositional_explanations-0.0.0.1/pyproject.toml +29 -0
  5. compositional_explanations-0.0.0.1/setup.cfg +4 -0
  6. compositional_explanations-0.0.0.1/src/compositional_explanations/__init__.py +0 -0
  7. compositional_explanations-0.0.0.1/src/compositional_explanations/beam.py +405 -0
  8. compositional_explanations-0.0.0.1/src/compositional_explanations/formula.py +418 -0
  9. compositional_explanations-0.0.0.1/src/compositional_explanations/metrics.py +65 -0
  10. compositional_explanations-0.0.0.1/src/compositional_explanations/optimal.py +947 -0
  11. compositional_explanations-0.0.0.1/src/compositional_explanations/optimal_sample_heuristic.py +776 -0
  12. compositional_explanations-0.0.0.1/src/compositional_explanations/optimal_sum_heuristic.py +635 -0
  13. compositional_explanations-0.0.0.1/src/compositional_explanations/path_heuristic.py +675 -0
  14. compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/PKG-INFO +86 -0
  15. compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/SOURCES.txt +20 -0
  16. compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/dependency_links.txt +1 -0
  17. compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/requires.txt +3 -0
  18. compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/top_level.txt +2 -0
  19. compositional_explanations-0.0.0.1/src/utils/constants.py +30 -0
  20. compositional_explanations-0.0.0.1/src/utils/general_utils.py +28 -0
  21. compositional_explanations-0.0.0.1/src/utils/mask_utils.py +125 -0
  22. compositional_explanations-0.0.0.1/src/utils/optimal_utils.py +590 -0
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2018 The Python Packaging Authority
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19
+ SOFTWARE.
@@ -0,0 +1,86 @@
1
+ Metadata-Version: 2.1
2
+ Name: compositional_explanations
3
+ Version: 0.0.0.1
4
+ Summary: A small example package
5
+ Author-email: Biagio La Rosa <bilarosa@ucsc.edu>
6
+ License: MIT
7
+ Project-URL: Homepage, https://github.com/aiea-lab/optimal-compositional-explanations
8
+ Project-URL: Issues, https://github.com/aiea-lab/optimal-compositional-explanations/issues
9
+ Platform: UNKNOWN
10
+ Classifier: Programming Language :: Python :: 3
11
+ Classifier: Operating System :: OS Independent
12
+ Requires-Python: >=3.8
13
+ Description-Content-Type: text/markdown
14
+ License-File: LICENSE
15
+ Requires-Dist: numpy>=1.22.2
16
+ Requires-Dist: scipy>=1.10.1
17
+ Requires-Dist: torch>=1.14.0
18
+
19
+
20
+ # Compositional Explanations Package
21
+ This package provides functions to compute compositional explanations for a given bitmap and a set of boolean masks. The package includes two main methods for computing explanations: optimal and beam search. Additionally, it provides metrics to evaluate the quality of the explanations.
22
+
23
+ The package assumes you are able to provide the following inputs:
24
+ - `bitmaps`: A boolean tensor representing the bitmap to be explained for a given neuron. The bitmap represents the activation of the neuron for a set of inputs. A cell is 1 if the neuron is activated for that feature or whole input, and 0 otherwise.
25
+ - `masks`: A list of boolean tensors representing the masks for each feature. Each mask represents whether a given concept/feature is annotated in the given position or samples. A cell is 1 if the feature is present in that position or sample, and 0 otherwise.
26
+ - `disjoint_info`: A boolean tensor representing the disjointness of the masks. The tensor is of shape (num_masks, num_masks) and indicates whether two masks are disjoint (i.e., they do not overlap). A cell is 1 if the two masks are disjoint, and 0 otherwise.
27
+ - `length`: An integer representing the maximum length of the explanation formula. The formula is a combination of masks that best explains the bitmap.
28
+ - `beam_size`: An integer representing the beam size for the beam search method. This parameter controls the number of candidate explanations to consider at each step of the search.
29
+ - `device`: The device to run the computations on, given as a `torch.device` (as in the example below) or a device string such as `cuda`. Using a GPU is strongly recommended for faster computations.
30
+ - `cache_dir`: A string representing the directory to cache heuristic information for the given masks. This is useful for large datasets where computing the heuristics can be time-consuming. Default is None, which means no caching will be used.
31
+
32
+ Useful functions include:
33
+ - `optimal.compute_optimal_explanations(bitmaps, masks, disjoint_info, length, device, cache_dir)`: Computes the optimal explanation formula for the given bitmap and masks. This method can be slow (~1 hour per neuron) for high complexity scenarios.
34
+ - `beam.compute_beam_explanations(bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir)`: Computes the explanation formula for the given bitmap and masks using a beam search method. This method is faster than the optimal method but cannot guarantee the optimality of the explanation. It is recommended to use this method for large datasets or when a faster computation is needed.
35
+ - `metrics.compute_iou_from_masks(formula, masks, bitmaps)`: Computes the Intersection over Union (IoU) metric for the given explanation formula, masks, and bitmap. The IoU metric measures the overlap between the explanation masks and the bitmap, providing a measure of how well the explanation captures the bitmap's activation.
36
+
37
+ ## Example usage of the compositional_explanations package
38
+
39
+ ```python
40
+ from compositional_explanations import beam, optimal, metrics
41
+ import torch
42
+
43
+ # Create 5 boolean random masks of shape (4, 4)
44
+ masks = [torch.randint(0, 2, (4, 4), dtype=torch.bool) for _ in range(5)]
45
+
46
+ concept_index = 0
47
+ for mask in masks:
48
+ print(f"Mask {concept_index}:")
49
+ print(mask)
50
+ concept_index += 1
51
+
52
+ # Create a random bitmap of shape (4, 4)
53
+ bitmap = torch.randint(0, 2, (4, 4), dtype=torch.bool)
54
+
55
+ print(f"Bitmap:")
56
+ print(bitmap)
57
+
58
+ # Compute the disjoint matrix (inefficient; we suggest computing it directly from the annotations)
59
+ disjoint_matrix = torch.ones((len(masks), len(masks)), dtype=torch.bool)
60
+ for i in range(len(masks)):
61
+ for j in range(i + 1, len(masks)):
62
+ disjoint_matrix[i, j] = not torch.any(masks[i] & masks[j])
63
+ disjoint_matrix[j, i] = disjoint_matrix[i, j]
64
+
65
+ print("Disjoint Matrix:")
66
+ print(disjoint_matrix)
67
+ # Compute optimal explanations
68
+ best_formula_optimal = optimal.compute_optimal_explanations(
69
+ bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, device=torch.device("cuda"), cache_dir=None
70
+ )
71
+ optimal_iou = metrics.compute_iou_from_masks(formula=best_formula_optimal, masks=masks, bitmaps=bitmap)
72
+
73
+ print("Optimal Explanation:")
74
+ print("Best formula:", best_formula_optimal)
75
+ print("Best IoU:", optimal_iou)
76
+
77
+ best_formula_beam = beam.compute_beam_explanations(
78
+ bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, beam_size=5, device=torch.device("cuda"), cache_dir=None
79
+ )
80
+ beam_iou = metrics.compute_iou_from_masks(formula=best_formula_beam, masks=masks, bitmaps=bitmap)
81
+ print("Beam Explanation:")
82
+ print("Best formula:", best_formula_beam)
83
+ print("Best IoU:", beam_iou)
84
+ print()
85
+ print("Beam Explanation is optimal:", beam_iou == optimal_iou)
86
+ ```
@@ -0,0 +1,68 @@
1
+
2
+ # Compositional Explanations Package
3
+ This package provides functions to compute compositional explanations for a given bitmap and a set of boolean masks. The package includes two main methods for computing explanations: optimal and beam search. Additionally, it provides metrics to evaluate the quality of the explanations.
4
+
5
+ The package assumes you are able to provide the following inputs:
6
+ - `bitmaps`: A boolean tensor representing the bitmap to be explained for a given neuron. The bitmap represents the activation of the neuron for a set of inputs. A cell is 1 if the neuron is activated for that feature or whole input, and 0 otherwise.
7
+ - `masks`: A list of boolean tensors representing the masks for each feature. Each mask represents whether a given concept/feature is annotated in the given position or samples. A cell is 1 if the feature is present in that position or sample, and 0 otherwise.
8
+ - `disjoint_info`: A boolean tensor representing the disjointness of the masks. The tensor is of shape (num_masks, num_masks) and indicates whether two masks are disjoint (i.e., they do not overlap). A cell is 1 if the two masks are disjoint, and 0 otherwise.
9
+ - `length`: An integer representing the maximum length of the explanation formula. The formula is a combination of masks that best explains the bitmap.
10
+ - `beam_size`: An integer representing the beam size for the beam search method. This parameter controls the number of candidate explanations to consider at each step of the search.
11
+ - `device`: The device to run the computations on, given as a `torch.device` (as in the example below) or a device string such as `cuda`. Using a GPU is strongly recommended for faster computations.
12
+ - `cache_dir`: A string representing the directory to cache heuristic information for the given masks. This is useful for large datasets where computing the heuristics can be time-consuming. Default is None, which means no caching will be used.
13
+
14
+ Useful functions include:
15
+ - `optimal.compute_optimal_explanations(bitmaps, masks, disjoint_info, length, device, cache_dir)`: Computes the optimal explanation formula for the given bitmap and masks. This method can be slow (~1 hour per neuron) for high complexity scenarios.
16
+ - `beam.compute_beam_explanations(bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir)`: Computes the explanation formula for the given bitmap and masks using a beam search method. This method is faster than the optimal method but cannot guarantee the optimality of the explanation. It is recommended to use this method for large datasets or when a faster computation is needed.
17
+ - `metrics.compute_iou_from_masks(formula, masks, bitmaps)`: Computes the Intersection over Union (IoU) metric for the given explanation formula, masks, and bitmap. The IoU metric measures the overlap between the explanation masks and the bitmap, providing a measure of how well the explanation captures the bitmap's activation.
18
+
19
+ ## Example usage of the compositional_explanations package
20
+
21
+ ```python
22
+ from compositional_explanations import beam, optimal, metrics
23
+ import torch
24
+
25
+ # Create 5 boolean random masks of shape (4, 4)
26
+ masks = [torch.randint(0, 2, (4, 4), dtype=torch.bool) for _ in range(5)]
27
+
28
+ concept_index = 0
29
+ for mask in masks:
30
+ print(f"Mask {concept_index}:")
31
+ print(mask)
32
+ concept_index += 1
33
+
34
+ # Create a random bitmap of shape (4, 4)
35
+ bitmap = torch.randint(0, 2, (4, 4), dtype=torch.bool)
36
+
37
+ print(f"Bitmap:")
38
+ print(bitmap)
39
+
40
+ # Compute the disjoint matrix (inefficient; we suggest computing it directly from the annotations)
41
+ disjoint_matrix = torch.ones((len(masks), len(masks)), dtype=torch.bool)
42
+ for i in range(len(masks)):
43
+ for j in range(i + 1, len(masks)):
44
+ disjoint_matrix[i, j] = not torch.any(masks[i] & masks[j])
45
+ disjoint_matrix[j, i] = disjoint_matrix[i, j]
46
+
47
+ print("Disjoint Matrix:")
48
+ print(disjoint_matrix)
49
+ # Compute optimal explanations
50
+ best_formula_optimal = optimal.compute_optimal_explanations(
51
+ bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, device=torch.device("cuda"), cache_dir=None
52
+ )
53
+ optimal_iou = metrics.compute_iou_from_masks(formula=best_formula_optimal, masks=masks, bitmaps=bitmap)
54
+
55
+ print("Optimal Explanation:")
56
+ print("Best formula:", best_formula_optimal)
57
+ print("Best IoU:", optimal_iou)
58
+
59
+ best_formula_beam = beam.compute_beam_explanations(
60
+ bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, beam_size=5, device=torch.device("cuda"), cache_dir=None
61
+ )
62
+ beam_iou = metrics.compute_iou_from_masks(formula=best_formula_beam, masks=masks, bitmaps=bitmap)
63
+ print("Beam Explanation:")
64
+ print("Best formula:", best_formula_beam)
65
+ print("Best IoU:", beam_iou)
66
+ print()
67
+ print("Beam Explanation is optimal:", beam_iou == optimal_iou)
68
+ ```
@@ -0,0 +1,29 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "compositional_explanations"
7
+ version = "0.0.0.1"
8
+ authors = [
9
+ { name="Biagio La Rosa", email="bilarosa@ucsc.edu" },
10
+ ]
11
+ description = "A small example package"
12
+ readme = "README.md"
13
+ requires-python = ">=3.8"
14
+ classifiers = [
15
+ "Programming Language :: Python :: 3",
16
+ "Operating System :: OS Independent",
17
+ ]
18
+
19
+ dependencies = [
20
+ "numpy>=1.22.2",
21
+ "scipy>=1.10.1",
22
+ "torch>=1.14.0",
23
+ ]
24
+
25
+ license = {text = "MIT"}
26
+
27
+ [project.urls]
28
+ Homepage = "https://github.com/aiea-lab/optimal-compositional-explanations"
29
+ Issues = "https://github.com/aiea-lab/optimal-compositional-explanations/issues"
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,405 @@
1
+ """Beam-search variant guided by the optimal heuristic family."""
2
+
3
+ from collections import Counter
4
+ import queue as Q
5
+
6
+ import torch
7
+
8
+ from compositional_explanations import formula as F
9
+ from compositional_explanations import optimal_sample_heuristic, metrics
10
+ from utils import optimal_utils, mask_utils
11
+
12
def _beam_accepts(beam_queue, score):
    """Return whether an entry with ``score`` should enter the beam.

    An entry is accepted while the queue still has room, or when its score is
    strictly greater than the lowest score currently stored in the queue.
    """
    # ``queue[0]`` is the heap root, i.e. the lowest-scored entry of the beam.
    return not beam_queue.full() or score > beam_queue.queue[0][0]


def _beam_insert(beam_queue, score, item):
    """Insert ``(score, item)``, evicting the lowest-scored entry if full."""
    if beam_queue.full():
        beam_queue.get()
    beam_queue.put((score, item))


def beam_search(
    search_space,
    *,
    masks,
    beam_masks,
    bitmaps,
    beam_limit,
    previous_beam=None,
):
    """Select the best formulas among the previous beam and the search space.

    The beam is first seeded with the entries of ``previous_beam``. Then the
    candidates of ``search_space`` are scanned in the given order: the exact
    IoU of each candidate is computed and the candidate enters the beam if
    there is room or if it beats the current worst entry. The scan stops at
    the first candidate whose estimated IoU (its ``iou`` attribute on entry)
    is lower than the worst IoU of a full beam, so ``search_space`` is
    expected to be sorted by estimated IoU in descending order.

    Args:
        search_space (list): Candidate formulas. Each one must expose an
            ``iou`` attribute holding its estimated IoU; the attribute is
            overwritten with the exact IoU when the candidate enters the beam.
        masks: Concept masks, forwarded unchanged to
            ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cache of already computed formula masks, forwarded
            unchanged to ``mask_utils.get_formula_mask``.
        bitmaps (torch.Tensor): Activations to explain. Its ``device`` is used
            to compute the formula masks.
        beam_limit (int): The beam size. Note that ``queue.PriorityQueue``
            treats a value ``<= 0`` as "unbounded", in which case nothing is
            ever evicted and the scan never stops early.
        previous_beam (dict, optional): Mapping from formulas of the previous
            beam to their IoU. Defaults to an empty beam.

    Returns:
        tuple: ``(formulas, ious, visited)`` where ``formulas`` and ``ious``
        are parallel lists ordered by ascending IoU, and ``visited`` is the
        number of candidates whose exact IoU was evaluated.
    """
    if previous_beam is None:
        previous_beam = {}

    # NOTE(review): entries are ``(iou, formula)`` tuples, so two entries with
    # the same IoU are ordered by comparing the formulas themselves; this
    # assumes formulas are orderable -- confirm against the formula module.
    current_beam = Q.PriorityQueue(beam_limit)
    visited_indices = 0

    # Seed the beam with the best entries of the previous beam.
    for label, iou in previous_beam.items():
        if _beam_accepts(current_beam, iou):
            _beam_insert(current_beam, iou, label)

    for candidate_formula in search_space:
        # The estimate cannot beat the worst entry of a full beam: stop here.
        if current_beam.full() and candidate_formula.iou < current_beam.queue[0][0]:
            break

        # Exact IoU of the candidate.
        masks_formula = mask_utils.get_formula_mask(
            candidate_formula, masks, beam_masks, device=bitmaps.device
        )
        iou = metrics.iou(masks_formula, bitmaps)
        visited_indices += 1

        if _beam_accepts(current_beam, iou):
            # Replace the estimate with the exact value before storing it.
            candidate_formula.iou = iou
            _beam_insert(current_beam, iou, candidate_formula)

    # Drain the queue: entries come out from the lowest to the highest IoU.
    current_beam_formulas = []
    current_beam_iou = []
    for _ in range(current_beam.qsize()):
        iou, candidate = current_beam.get()
        current_beam_formulas.append(candidate)
        current_beam_iou.append(iou)
    return current_beam_formulas, current_beam_iou, visited_indices
102
+
103
+
104
+ def compute_next_search_space(formulas, candidate_labels):
105
+ """Compute the next search space starting from the current beam
106
+ of formulas.
107
+
108
+ Args:
109
+ formulas (list): A list of formulas.
110
+ candidate_labels (list): A list of candidate labels.
111
+
112
+ Returns:
113
+ search_space (list): A list of formulas.
114
+ """
115
+ search_space = []
116
+
117
+ for formula in formulas:
118
+ vals_formula = set(formula.get_vals())
119
+ for candidate_term in candidate_labels:
120
+ # remove dummy cases with void masks or equivalent formulas
121
+ if candidate_term.val in vals_formula:
122
+ continue
123
+ for op, negate in [(F.Or, False), (F.And, False), (F.And, True)]:
124
+ candidate_to_attach = candidate_term
125
+ if negate:
126
+ candidate_to_attach = F.Not(candidate_to_attach)
127
+ candidate_formula = op(formula, candidate_to_attach)
128
+ candidate_formula.iou = 1.0
129
+
130
+ search_space.append(candidate_formula)
131
+ search_space = list(set(search_space))
132
+ return search_space
133
+
134
+
135
def compute_beam_quantities(beam, masks, beam_masks, heuristic_info, bitmaps):
    """Compute the heuristic information of the new formulas in the beam.

    Leaves and formulas whose mask is already cached in ``beam_masks`` are
    skipped. For every other formula the mask is computed, turned into a
    quantities vector and then into the concept information; the mask is
    finally cached (on CPU) in ``beam_masks``.

    Args:
        beam: Iterable of the formulas currently in the beam.
        masks: Concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cache mapping formulas to their masks. It is
            updated in place with the newly computed masks.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the optimal heuristic.
        bitmaps (torch.Tensor): Activations to explain; computations run on
            its device.

    Returns:
        tuple: ``(beam_info, beam_masks)`` where ``beam_info`` maps each newly
        processed formula to its information.
    """
    masks_info, neuron_info, _ = heuristic_info
    (neuron_unique, _), (neuron_common, _), _, _, _, _ = neuron_info
    shared_elements, exclusive_elements, _ = masks_info

    # These tensors are reused for every formula: move them once to the
    # device of the bitmaps.
    target_device = bitmaps.device
    exclusive_elements = exclusive_elements.to(target_device)
    shared_elements = shared_elements.to(target_device)

    beam_info = {}
    for label in beam:
        if label not in beam_masks and not isinstance(label, F.Leaf):
            formula_mask = mask_utils.get_formula_mask(
                label, masks, beam_masks, device=target_device
            )
            quantities = optimal_utils.compute_quantities_vector(
                label_mask=formula_mask,
                bitmaps=bitmaps,
                common_elements=shared_elements,
                unique_elements=exclusive_elements,
                neuron_common=neuron_common,
                neuron_unique=neuron_unique,
            )
            beam_info[label] = optimal_utils.get_concept_info(quantities)
            # Cache the mask on CPU.
            beam_masks[label] = formula_mask.cpu()
    return beam_info, beam_masks
179
+
180
+
181
def get_beam_info(
    *, beam, masks, beam_masks, heuristic_info, label_mapping, bitmaps, length
):
    """Refresh the mask cache and the heuristic data for the current beam.

    The information of the formulas in the beam is computed with
    ``compute_beam_quantities`` and then merged into the heuristic
    information and the label mapping via
    ``optimal_utils.update_heuristic_info``.

    Args:
        beam: Iterable of the formulas currently in the beam.
        masks: Concept masks.
        beam_masks (dict): Cache mapping formulas to their masks.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the optimal heuristic.
        label_mapping (dict): Mapping from labels to the index of their
            entry in the heuristic information.
        bitmaps (torch.Tensor): Activations to explain.
        length (int): The maximum length of the formulas to search.

    Returns:
        tuple: ``(beam_masks, (heuristic_info, label_mapping))`` with the
        updated mask cache, heuristic information and label mapping.
    """
    nodes_info, beam_masks = compute_beam_quantities(
        beam=beam,
        masks=masks,
        beam_masks=beam_masks,
        heuristic_info=heuristic_info,
        bitmaps=bitmaps,
    )
    updated_heuristic_info, updated_label_mapping = (
        optimal_utils.update_heuristic_info(
            nodes_info=nodes_info,
            heuristic_info=heuristic_info,
            label_mapping=label_mapping,
            max_length=length,
        )
    )
    updated_state = (updated_heuristic_info, updated_label_mapping)
    return beam_masks, updated_state
206
+
207
+
208
def sort_search_space_by(
    *,
    search_space,
    label_mapping,
    heuristic_info,
    disjoint_info,
    num_hits,
    max_size_mask
):
    """Order the search space by the estimated IoU of each formula.

    The estimate of every formula is stored in its ``iou`` attribute
    (overwriting the previous value) and a new list sorted by that estimate,
    from the highest to the lowest, is returned.

    Args:
        search_space (list): The formulas to sort.
        label_mapping (dict): Mapping from labels to the index of their
            entry in the heuristic information.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the optimal heuristic.
        disjoint_info: Disjointness information for the concepts.
        num_hits: The number of hits in the bitmaps.
        max_size_mask (int): The maximum size of the mask.

    Returns:
        list: The formulas sorted by estimated IoU in descending order.
    """
    _, neuron_quantities, _ = heuristic_info

    def estimate_iou(candidate):
        quantities = optimal_utils.estimate_label_quantities(
            heuristic=optimal_sample_heuristic,
            label=candidate,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )
        if quantities is None:
            # The label was discarded at the previous step.
            return 0.0
        return optimal_utils.compute_max_iou_from_label_info(
            label_quantities=quantities,
            num_hits=num_hits,
            neuron_quantities=neuron_quantities,
        )

    for candidate in search_space:
        candidate.iou = estimate_iou(candidate)

    return sorted(search_space, key=lambda formula: formula.iou, reverse=True)
253
+
254
+
255
def perform_search(
    *,
    masks,
    bitmaps,
    heuristic_info,
    disjoint_info,
    max_size_mask,
    beam_size=5,
    length=3
):
    """Run the heuristic-guided beam search and return the best formula.

    The first beam holds the (at most ``2 * beam_size``) atomic concepts with
    the highest positive IoU. At each following step the formulas of the
    previous length are expanded by one term, the resulting search space is
    sorted by estimated IoU and ``beam_search`` selects the formulas to keep.
    In the last step the beam size is reduced to 1.

    Args:
        masks (list): The concept masks.
        bitmaps (torch.Tensor): Activations to explain.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the optimal heuristic.
        disjoint_info: Disjointness information for the concepts.
        max_size_mask (int): The maximum size of the mask.
        beam_size (int): The beam size for the search.
        length (int): The maximum length of the formulas to search.

    Returns:
        The formula with the highest IoU found by the search, or ``None`` if
        no atomic concept has an IoU greater than zero.
    """
    # Number of hits
    num_hits = bitmaps.sum()

    # Map every atomic concept to the index of its mask
    label_mapping = {F.Leaf(c): c for c in range(len(masks))}

    # Candidate concepts that can be attached to a formula
    candidate_labels = [F.Leaf(c) for c in range(len(masks))]

    _, neuron_quantities, concept_quantities = heuristic_info
    # Keyword arguments keep this call consistent with the one in
    # ``sort_search_space_by`` and independent of the positional order.
    iou_atoms = Counter(
        {
            atom: optimal_utils.compute_max_iou_from_label_info(
                label_quantities=concept_quantities[label_mapping[atom]],
                neuron_quantities=neuron_quantities,
                num_hits=num_hits,
            )
            for atom in candidate_labels
        }
    )

    # The first beam is twice as large as the following ones
    first_beam_num = min(len(iou_atoms), beam_size * 2)
    beam = {
        lab: iou for lab, iou in iou_atoms.most_common(first_beam_num) if iou > 0
    }

    if not beam:
        return None

    # Beam Search
    beam_masks = {}
    for previous_beam_length in range(1, length):
        is_last_step = previous_beam_length == length - 1

        # Only expand formulas of the previous beam length to avoid
        # regenerating parts of the beam tree that were already explored
        to_expand = [lab for lab in beam if len(lab) == previous_beam_length]
        search_space = compute_next_search_space(to_expand, candidate_labels)

        # Compute the estimation for the next frontier
        sorted_search_space = sort_search_space_by(
            search_space=search_space,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            num_hits=num_hits,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )

        # In the last iteration only the best formula is needed
        if is_last_step:
            beam_size = 1

        next_beam_formulas, next_beam_iou, _ = beam_search(
            sorted_search_space,
            masks=masks,
            previous_beam=beam,
            beam_masks=beam_masks,
            bitmaps=bitmaps,
            beam_limit=beam_size,
        )

        # Update top formulas
        for formula, iou in zip(next_beam_formulas, next_beam_iou):
            beam[formula] = iou

        # Trim the beam
        beam = dict(Counter(beam).most_common(beam_size))

        # Update the heuristic information only if there are steps left
        if not is_last_step:
            beam_masks, (heuristic_info, label_mapping) = get_beam_info(
                beam=beam.keys(),
                masks=masks,
                beam_masks=beam_masks,
                heuristic_info=heuristic_info,
                label_mapping=label_mapping,
                bitmaps=bitmaps,
                length=length,
            )

    best_label, _ = Counter(beam).most_common(1)[0]
    return best_label
363
+
364
+
365
def compute_beam_explanations(
    *, bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir=None
):
    """Compute the beam-search explanation for the given bitmaps and masks.

    The bitmaps are moved to ``device``, the information needed by the
    optimal heuristic is gathered from the masks and the bitmaps, and the
    search itself is delegated to ``perform_search``.

    Args:
        bitmaps (torch.Tensor): Activations to explain.
        masks (list): The concept masks; the number of elements of the first
            mask is used as the maximum mask size.
        disjoint_info: Disjointness information for the concepts (the
            package README describes it as a boolean tensor of shape
            (num_masks, num_masks)).
        length (int): The maximum length of the formulas to search.
        beam_size (int): The beam size for the search.
        device (torch.device): The device the bitmaps are moved to.
        cache_dir (str, optional): Directory forwarded to
            ``mask_utils.get_dataset_info`` to cache the information of the
            masks. Defaults to None.

    Returns:
        The best formula returned by ``perform_search`` (``None`` when the
        search finds no atomic concept with a positive IoU).
    """
    mask_size = torch.numel(masks[0])
    bitmaps = bitmaps.to(device)

    # Information that depends only on the masks (possibly cached on disk)
    masks_info = mask_utils.get_dataset_info(masks, cache_dir)

    # Quantities that depend on both the masks and the bitmaps
    neuron_quantities, concept_quantities = optimal_utils.get_optimal_heuristic_info(
        masks=masks, bitmaps=bitmaps, masks_quantities=masks_info
    )

    return perform_search(
        masks=masks,
        bitmaps=bitmaps,
        heuristic_info=(masks_info, neuron_quantities, concept_quantities),
        disjoint_info=disjoint_info,
        max_size_mask=mask_size,
        beam_size=beam_size,
        length=length,
    )