compositional-explanations 0.0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compositional_explanations-0.0.0.1/LICENSE +19 -0
- compositional_explanations-0.0.0.1/PKG-INFO +86 -0
- compositional_explanations-0.0.0.1/README.md +68 -0
- compositional_explanations-0.0.0.1/pyproject.toml +29 -0
- compositional_explanations-0.0.0.1/setup.cfg +4 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/__init__.py +0 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/beam.py +405 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/formula.py +418 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/metrics.py +65 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/optimal.py +947 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/optimal_sample_heuristic.py +776 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/optimal_sum_heuristic.py +635 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations/path_heuristic.py +675 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/PKG-INFO +86 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/SOURCES.txt +20 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/dependency_links.txt +1 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/requires.txt +3 -0
- compositional_explanations-0.0.0.1/src/compositional_explanations.egg-info/top_level.txt +2 -0
- compositional_explanations-0.0.0.1/src/utils/constants.py +30 -0
- compositional_explanations-0.0.0.1/src/utils/general_utils.py +28 -0
- compositional_explanations-0.0.0.1/src/utils/mask_utils.py +125 -0
- compositional_explanations-0.0.0.1/src/utils/optimal_utils.py +590 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Copyright (c) 2018 The Python Packaging Authority
|
|
2
|
+
|
|
3
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
4
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
5
|
+
in the Software without restriction, including without limitation the rights
|
|
6
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
7
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
8
|
+
furnished to do so, subject to the following conditions:
|
|
9
|
+
|
|
10
|
+
The above copyright notice and this permission notice shall be included in all
|
|
11
|
+
copies or substantial portions of the Software.
|
|
12
|
+
|
|
13
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
14
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
15
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
16
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
17
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
18
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
19
|
+
SOFTWARE.
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: compositional_explanations
|
|
3
|
+
Version: 0.0.0.1
|
|
4
|
+
Summary: A small example package
|
|
5
|
+
Author-email: Biagio La Rosa <bilarosa@ucsc.edu>
|
|
6
|
+
License: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/aiea-lab/optimal-compositional-explanations
|
|
8
|
+
Project-URL: Issues, https://github.com/aiea-lab/optimal-compositional-explanations/issues
|
|
9
|
+
Platform: UNKNOWN
|
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
|
11
|
+
Classifier: Operating System :: OS Independent
|
|
12
|
+
Requires-Python: >=3.8
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
License-File: LICENSE
|
|
15
|
+
Requires-Dist: numpy>=1.22.2
|
|
16
|
+
Requires-Dist: scipy>=1.10.1
|
|
17
|
+
Requires-Dist: torch>=1.14.0
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Compositional Explanations Package
|
|
21
|
+
This package provides functions to compute compositional explanations for a given bitmap and a set of boolean masks. The package includes two main methods for computing explanations: optimal and beam search. Additionally, it provides metrics to evaluate the quality of the explanations.
|
|
22
|
+
|
|
23
|
+
The package assumes you are able to provide the following inputs:
|
|
24
|
+
- `bitmaps`: A boolean tensor representing the bitmap to be explained for a given neuron. The bitmap represents the activation of the neuron for a set of inputs. A cell is 1 if the neuron is activated for that feature or whole input, and 0 otherwise.
|
|
25
|
+
- `masks`: A list of boolean tensors representing the masks for each feature. Each mask represents whether a given concept/feature is annotated in the given position or samples. A cell is 1 if the feature is present in that position or sample, and 0 otherwise.
|
|
26
|
+
- `disjoint_info`: A boolean tensor representing the disjointness of the masks. The tensor is of shape (num_masks, num_masks) and indicates whether two masks are disjoint (i.e., they do not overlap). A cell is 1 if the two masks are disjoint, and 0 otherwise.
|
|
27
|
+
- `length`: An integer representing the maximum length of the explanation formula. The formula is a combination of masks that best explains the bitmap.
|
|
28
|
+
- `beam_size`: An integer representing the beam size for the beam search method. This parameter controls the number of candidate explanations to consider at each step of the search.
|
|
29
|
+
- `device`: The device to run the computations on (the example below passes a `torch.device`, e.g. `torch.device("cuda")`). Using a GPU is highly recommended for faster computations.
|
|
30
|
+
- `cache_dir`: A string representing the directory in which to cache heuristic information for the given masks. This is useful for large datasets where computing the heuristics can be time-consuming. Default is None, which means no caching will be used.
|
|
31
|
+
|
|
32
|
+
Useful functions include:
|
|
33
|
+
- `optimal.compute_optimal_explanations(bitmaps, masks, disjoint_info, length, device, cache_dir)`: Computes the optimal explanation formula for the given bitmap and masks. This method can be slow (~1 hour per neuron) for high complexity scenarios.
|
|
34
|
+
- `beam.compute_beam_explanations(bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir)`: Computes the explanation formula for the given bitmap and masks using a beam search method. This method is faster than the optimal method but cannot guarantee the optimality of the explanation. It is recommended to use this method for large datasets or when a faster computation is needed.
|
|
35
|
+
- `metrics.compute_iou_from_masks(formula, masks, bitmaps)`: Computes the Intersection over Union (IoU) metric for the given explanation formula, masks, and bitmap. The IoU metric measures the overlap between the explanation masks and the bitmap, providing a measure of how well the explanation captures the bitmap's activation.
|
|
36
|
+
|
|
37
|
+
## Example usage of the compositional_explanations package
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from compositional_explanations import beam, optimal, metrics
|
|
41
|
+
import torch
|
|
42
|
+
|
|
43
|
+
# Create 5 boolean random masks of shape (4, 4)
|
|
44
|
+
masks = [torch.randint(0, 2, (4, 4), dtype=torch.bool) for _ in range(5)]
|
|
45
|
+
|
|
46
|
+
concept_index = 0
|
|
47
|
+
for mask in masks:
|
|
48
|
+
print(f"Mask {concept_index}:")
|
|
49
|
+
print(mask)
|
|
50
|
+
concept_index += 1
|
|
51
|
+
|
|
52
|
+
# Create a random bitmap of shape (4, 4)
|
|
53
|
+
bitmap = torch.randint(0, 2, (4, 4), dtype=torch.bool)
|
|
54
|
+
|
|
55
|
+
print(f"Bitmap:")
|
|
56
|
+
print(bitmap)
|
|
57
|
+
|
|
58
|
+
# Compute the disjoint matrix (inefficient; we suggest computing it directly from the annotations)
|
|
59
|
+
disjoint_matrix = torch.ones((len(masks), len(masks)), dtype=torch.bool)
|
|
60
|
+
for i in range(len(masks)):
|
|
61
|
+
for j in range(i + 1, len(masks)):
|
|
62
|
+
disjoint_matrix[i, j] = not torch.any(masks[i] & masks[j])
|
|
63
|
+
disjoint_matrix[j, i] = disjoint_matrix[i, j]
|
|
64
|
+
|
|
65
|
+
print("Disjoint Matrix:")
|
|
66
|
+
print(disjoint_matrix)
|
|
67
|
+
# Compute optimal explanations
|
|
68
|
+
best_formula_optimal = optimal.compute_optimal_explanations(
|
|
69
|
+
bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, device=torch.device("cuda"), cache_dir=None
|
|
70
|
+
)
|
|
71
|
+
optimal_iou = metrics.compute_iou_from_masks(formula=best_formula_optimal, masks=masks, bitmaps=bitmap)
|
|
72
|
+
|
|
73
|
+
print("Optimal Explanation:")
|
|
74
|
+
print("Best formula:", best_formula_optimal)
|
|
75
|
+
print("Best IoU:", optimal_iou)
|
|
76
|
+
|
|
77
|
+
best_formula_beam = beam.compute_beam_explanations(
|
|
78
|
+
bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, beam_size=5, device=torch.device("cuda"), cache_dir=None
|
|
79
|
+
)
|
|
80
|
+
beam_iou = metrics.compute_iou_from_masks(formula=best_formula_beam, masks=masks, bitmaps=bitmap)
|
|
81
|
+
print("Beam Explanation:")
|
|
82
|
+
print("Best formula:", best_formula_beam)
|
|
83
|
+
print("Best IoU:", beam_iou)
|
|
84
|
+
print()
|
|
85
|
+
print("Beam Explanation is optimal:", beam_iou == optimal_iou)
|
|
86
|
+
```
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
|
|
2
|
+
# Compositional Explanations Package
|
|
3
|
+
This package provides functions to compute compositional explanations for a given bitmap and a set of boolean masks. The package includes two main methods for computing explanations: optimal and beam search. Additionally, it provides metrics to evaluate the quality of the explanations.
|
|
4
|
+
|
|
5
|
+
The package assumes you are able to provide the following inputs:
|
|
6
|
+
- `bitmaps`: A boolean tensor representing the bitmap to be explained for a given neuron. The bitmap represents the activation of the neuron for a set of inputs. A cell is 1 if the neuron is activated for that feature or whole input, and 0 otherwise.
|
|
7
|
+
- `masks`: A list of boolean tensors representing the masks for each feature. Each mask represents whether a given concept/feature is annotated in the given position or samples. A cell is 1 if the feature is present in that position or sample, and 0 otherwise.
|
|
8
|
+
- `disjoint_info`: A boolean tensor representing the disjointness of the masks. The tensor is of shape (num_masks, num_masks) and indicates whether two masks are disjoint (i.e., they do not overlap). A cell is 1 if the two masks are disjoint, and 0 otherwise.
|
|
9
|
+
- `length`: An integer representing the maximum length of the explanation formula. The formula is a combination of masks that best explains the bitmap.
|
|
10
|
+
- `beam_size`: An integer representing the beam size for the beam search method. This parameter controls the number of candidate explanations to consider at each step of the search.
|
|
11
|
+
- `device`: The device to run the computations on (the example below passes a `torch.device`, e.g. `torch.device("cuda")`). Using a GPU is highly recommended for faster computations.
|
|
12
|
+
- `cache_dir`: A string representing the directory in which to cache heuristic information for the given masks. This is useful for large datasets where computing the heuristics can be time-consuming. Default is None, which means no caching will be used.
|
|
13
|
+
|
|
14
|
+
Useful functions include:
|
|
15
|
+
- `optimal.compute_optimal_explanations(bitmaps, masks, disjoint_info, length, device, cache_dir)`: Computes the optimal explanation formula for the given bitmap and masks. This method can be slow (~1 hour per neuron) for high complexity scenarios.
|
|
16
|
+
- `beam.compute_beam_explanations(bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir)`: Computes the explanation formula for the given bitmap and masks using a beam search method. This method is faster than the optimal method but cannot guarantee the optimality of the explanation. It is recommended to use this method for large datasets or when a faster computation is needed.
|
|
17
|
+
- `metrics.compute_iou_from_masks(formula, masks, bitmaps)`: Computes the Intersection over Union (IoU) metric for the given explanation formula, masks, and bitmap. The IoU metric measures the overlap between the explanation masks and the bitmap, providing a measure of how well the explanation captures the bitmap's activation.
|
|
18
|
+
|
|
19
|
+
## Example usage of the compositional_explanations package
|
|
20
|
+
|
|
21
|
+
```python
|
|
22
|
+
from compositional_explanations import beam, optimal, metrics
|
|
23
|
+
import torch
|
|
24
|
+
|
|
25
|
+
# Create 5 boolean random masks of shape (4, 4)
|
|
26
|
+
masks = [torch.randint(0, 2, (4, 4), dtype=torch.bool) for _ in range(5)]
|
|
27
|
+
|
|
28
|
+
concept_index = 0
|
|
29
|
+
for mask in masks:
|
|
30
|
+
print(f"Mask {concept_index}:")
|
|
31
|
+
print(mask)
|
|
32
|
+
concept_index += 1
|
|
33
|
+
|
|
34
|
+
# Create a random bitmap of shape (4, 4)
|
|
35
|
+
bitmap = torch.randint(0, 2, (4, 4), dtype=torch.bool)
|
|
36
|
+
|
|
37
|
+
print(f"Bitmap:")
|
|
38
|
+
print(bitmap)
|
|
39
|
+
|
|
40
|
+
# Compute the disjoint matrix (inefficient; we suggest computing it directly from the annotations)
|
|
41
|
+
disjoint_matrix = torch.ones((len(masks), len(masks)), dtype=torch.bool)
|
|
42
|
+
for i in range(len(masks)):
|
|
43
|
+
for j in range(i + 1, len(masks)):
|
|
44
|
+
disjoint_matrix[i, j] = not torch.any(masks[i] & masks[j])
|
|
45
|
+
disjoint_matrix[j, i] = disjoint_matrix[i, j]
|
|
46
|
+
|
|
47
|
+
print("Disjoint Matrix:")
|
|
48
|
+
print(disjoint_matrix)
|
|
49
|
+
# Compute optimal explanations
|
|
50
|
+
best_formula_optimal = optimal.compute_optimal_explanations(
|
|
51
|
+
bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, device=torch.device("cuda"), cache_dir=None
|
|
52
|
+
)
|
|
53
|
+
optimal_iou = metrics.compute_iou_from_masks(formula=best_formula_optimal, masks=masks, bitmaps=bitmap)
|
|
54
|
+
|
|
55
|
+
print("Optimal Explanation:")
|
|
56
|
+
print("Best formula:", best_formula_optimal)
|
|
57
|
+
print("Best IoU:", optimal_iou)
|
|
58
|
+
|
|
59
|
+
best_formula_beam = beam.compute_beam_explanations(
|
|
60
|
+
bitmaps=bitmap, masks=masks, disjoint_info=disjoint_matrix, length=3, beam_size=5, device=torch.device("cuda"), cache_dir=None
|
|
61
|
+
)
|
|
62
|
+
beam_iou = metrics.compute_iou_from_masks(formula=best_formula_beam, masks=masks, bitmaps=bitmap)
|
|
63
|
+
print("Beam Explanation:")
|
|
64
|
+
print("Best formula:", best_formula_beam)
|
|
65
|
+
print("Best IoU:", beam_iou)
|
|
66
|
+
print()
|
|
67
|
+
print("Beam Explanation is optimal:", beam_iou == optimal_iou)
|
|
68
|
+
```
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "compositional_explanations"
|
|
7
|
+
version = "0.0.0.1"
|
|
8
|
+
authors = [
|
|
9
|
+
{ name="Biagio La Rosa", email="bilarosa@ucsc.edu" },
|
|
10
|
+
]
|
|
11
|
+
description = "A small example package"
|
|
12
|
+
readme = "README.md"
|
|
13
|
+
requires-python = ">=3.8"
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Programming Language :: Python :: 3",
|
|
16
|
+
"Operating System :: OS Independent",
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
dependencies = [
|
|
20
|
+
"numpy>=1.22.2",
|
|
21
|
+
"scipy>=1.10.1",
|
|
22
|
+
"torch>=1.14.0",
|
|
23
|
+
]
|
|
24
|
+
|
|
25
|
+
license = {text = "MIT"}
|
|
26
|
+
|
|
27
|
+
[project.urls]
|
|
28
|
+
Homepage = "https://github.com/aiea-lab/optimal-compositional-explanations"
|
|
29
|
+
Issues = "https://github.com/aiea-lab/optimal-compositional-explanations/issues"
|
|
File without changes
|
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""Beam-search variant guided by the optimal heuristic family."""
|
|
2
|
+
|
|
3
|
+
from collections import Counter
|
|
4
|
+
import queue as Q
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from compositional_explanations import formula as F
|
|
9
|
+
from compositional_explanations import optimal_sample_heuristic, metrics
|
|
10
|
+
from utils import optimal_utils, mask_utils
|
|
11
|
+
|
|
12
|
+
def beam_search(
    search_space,
    *,
    masks,
    beam_masks,
    bitmaps,
    beam_limit,
    previous_beam=None,
):
    """Pick the highest-IoU formulas from the search space.

    The beam is a bounded min-priority queue of ``(iou, formula)`` entries.
    It is first seeded with ``previous_beam``; then each candidate of
    ``search_space`` has its exact IoU computed and replaces the weakest
    entry when it scores strictly higher.

    Args:
        search_space (iterable): Candidate formulas. Each must expose an
            ``iou`` attribute holding its estimated IoU. The early exit
            below stops at the first estimate under the beam minimum, so it
            presumes candidates arrive sorted by estimate, best first.
        masks: Concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks: Cache of already computed formula masks, forwarded to
            ``mask_utils.get_formula_mask``.
        bitmaps: Activations to explain; must expose ``device`` and is
            forwarded to ``metrics.iou``.
        beam_limit (int): Maximum number of entries kept in the beam.
        previous_beam (dict, optional): Mapping ``formula -> IoU`` used to
            seed the beam. Defaults to an empty beam.

    Returns:
        tuple: ``(formulas, ious, visited)`` where ``formulas`` and ``ious``
        are parallel lists in ascending IoU order and ``visited`` is the
        number of candidates whose exact IoU was evaluated.
    """
    # NOTE(review): entries are (iou, formula) tuples, so two entries with the
    # same IoU are ordered by comparing the formulas themselves. This assumes
    # formula objects are orderable - confirm against formula.py.
    beam = Q.PriorityQueue(beam_limit)
    minimum = 0

    def push(iou, formula):
        """Insert an entry, evicting the current weakest one when full."""
        nonlocal minimum
        if beam.full():
            beam.get()
        beam.put((iou, formula))
        minimum = beam.queue[0][0]

    # Seed the beam with the best entries of the previous beam.
    for formula, iou in (previous_beam or {}).items():
        if not beam.full() or iou > minimum:
            push(iou, formula)

    visited = 0
    for candidate in search_space:
        # Once the beam is full, a candidate whose estimate is already below
        # the weakest entry ends the scan.
        if beam.full() and candidate.iou < minimum:
            break

        formula_mask = mask_utils.get_formula_mask(
            candidate, masks, beam_masks, device=bitmaps.device
        )
        iou = metrics.iou(formula_mask, bitmaps)
        visited += 1

        if not beam.full() or iou > minimum:
            # Replace the estimate with the exact value before storing.
            candidate.iou = iou
            push(iou, candidate)

    # Drain the queue: entries come out weakest first.
    formulas = []
    ious = []
    while not beam.empty():
        iou, formula = beam.get()
        formulas.append(formula)
        ious.append(iou)
    return formulas, ious, visited
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_next_search_space(formulas, candidate_labels):
    """Expand every beam formula by one candidate term.

    For each formula and each candidate term whose ``val`` is not already
    among the formula's values, three expansions are produced:
    ``formula OR term``, ``formula AND term`` and ``formula AND NOT term``.

    Args:
        formulas (iterable): Formulas to expand; each must provide
            ``get_vals()``.
        candidate_labels (iterable): Candidate terms; each must expose
            ``val``.

    Returns:
        list: The de-duplicated expansions in first-generated order. Every
        expansion has its ``iou`` attribute initialised to ``1.0``.
    """
    search_space = []

    for formula in formulas:
        used_vals = set(formula.get_vals())
        for term in candidate_labels:
            # Skip terms already present in the formula: they would only
            # yield void masks or formulas equivalent to the current one.
            if term.val in used_vals:
                continue
            expansions = (
                F.Or(formula, term),
                F.And(formula, term),
                F.And(formula, F.Not(term)),
            )
            for expanded in expansions:
                expanded.iou = 1.0
                search_space.append(expanded)

    # dict.fromkeys de-duplicates with the same hash/equality rules as a set
    # but keeps insertion order, so the result (and any later tie-breaking
    # that depends on it) does not vary with hash ordering.
    return list(dict.fromkeys(search_space))
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def compute_beam_quantities(beam, masks, beam_masks, heuristic_info, bitmaps):
    """Compute heuristic info and masks for the new formulas of the beam.

    Formulas that already have a cached mask, and leaf formulas, are skipped.

    Args:
        beam (iterable): Formulas currently in the beam.
        masks: Concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cache ``formula -> mask``; updated in place with
            the masks computed here (moved to CPU).
        heuristic_info (tuple): ``(masks_info, neuron_quantities, _)``; only
            the first two entries are read.
        bitmaps: Activations to explain; its ``device`` is where the
            computation takes place.

    Returns:
        tuple: ``(beam_info, beam_masks)`` where ``beam_info`` maps each newly
        processed formula to its concept info and ``beam_masks`` is the same
        (mutated) cache that was passed in.
    """
    segmentation_quantities, neuron_quantities, _ = heuristic_info
    (neuron_unique, _), (neuron_common, _), _, _, _, _ = neuron_quantities
    common_elements, unique_elements, _ = segmentation_quantities

    # Shared by every formula below, so move them to the target device once.
    target_device = bitmaps.device
    unique_elements = unique_elements.to(target_device)
    common_elements = common_elements.to(target_device)

    beam_info = {}
    for formula in beam:
        already_known = formula in beam_masks or isinstance(formula, F.Leaf)
        if already_known:
            continue

        formula_mask = mask_utils.get_formula_mask(
            formula, masks, beam_masks, device=target_device
        )
        quantities = optimal_utils.compute_quantities_vector(
            label_mask=formula_mask,
            bitmaps=bitmaps,
            common_elements=common_elements,
            unique_elements=unique_elements,
            neuron_common=neuron_common,
            neuron_unique=neuron_unique,
        )
        beam_info[formula] = optimal_utils.get_concept_info(quantities)

        # The cache keeps CPU copies of the masks.
        beam_masks[formula] = formula_mask.cpu()

    return beam_info, beam_masks
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_beam_info(
    *, beam, masks, beam_masks, heuristic_info, label_mapping, bitmaps, length
):
    """Refresh the mask cache and heuristic data for the current beam.

    Args:
        beam (iterable): Formulas currently in the beam.
        masks: Concept masks.
        beam_masks (dict): Cache ``formula -> mask`` for beam formulas.
        heuristic_info (tuple): Heuristic information to be extended.
        label_mapping (dict): Mapping from labels to indices, to be extended.
        bitmaps: Activations to explain.
        length (int): Maximum formula length, passed on as ``max_length``.

    Returns:
        tuple: ``(beam_masks, (heuristic_info, label_mapping))`` holding the
        updated mask cache and the values returned by
        ``optimal_utils.update_heuristic_info``.
    """
    nodes_info, refreshed_masks = compute_beam_quantities(
        beam, masks, beam_masks, heuristic_info, bitmaps
    )
    updated_heuristic_info, updated_label_mapping = (
        optimal_utils.update_heuristic_info(
            nodes_info=nodes_info,
            heuristic_info=heuristic_info,
            label_mapping=label_mapping,
            max_length=length,
        )
    )
    return refreshed_masks, (updated_heuristic_info, updated_label_mapping)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def sort_search_space_by(
    *,
    search_space,
    label_mapping,
    heuristic_info,
    disjoint_info,
    num_hits,
    max_size_mask
):
    """Rank the search space by estimated IoU, best first.

    Each formula's ``iou`` attribute is overwritten with its estimate before
    sorting.

    Args:
        search_space (list): Formulas to rank.
        label_mapping (dict): Mapping from labels to indices.
        heuristic_info (tuple): 3-tuple whose second entry holds the neuron
            quantities.
        disjoint_info: Disjointness information for the concepts.
        num_hits: Number of hits in the bitmaps.
        max_size_mask (int): Maximum size of a mask.

    Returns:
        list: A new list with the same formulas in descending order of
        estimated IoU.
    """
    _, neuron_quantities, _ = heuristic_info

    for formula in search_space:
        quantities = optimal_utils.estimate_label_quantities(
            heuristic=optimal_sample_heuristic,
            label=formula,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )
        if quantities is not None:
            formula.iou = optimal_utils.compute_max_iou_from_label_info(
                label_quantities=quantities,
                num_hits=num_hits,
                neuron_quantities=neuron_quantities,
            )
        else:
            # No quantities available (label discarded earlier): rank it last.
            formula.iou = 0.0

    ranked = list(search_space)
    ranked.sort(key=lambda item: item.iou, reverse=True)
    return ranked
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def perform_search(
    *,
    masks,
    bitmaps,
    heuristic_info,
    disjoint_info,
    max_size_mask,
    beam_size=5,
    length=3
):
    """Run the heuristic-guided beam search and return the best formula.

    The first beam holds the atomic concepts with the highest positive IoU
    (at most ``2 * beam_size`` of them). Each following step expands the
    formulas of the previous length by one term, ranks the expansions by
    estimated IoU and keeps the best ``beam_size`` formulas. The last step
    keeps a single formula.

    Args:
        masks (list): Concept masks, one per concept.
        bitmaps: Activations to explain; ``bitmaps.sum()`` gives the number
            of hits.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the heuristic.
        disjoint_info: Disjointness information for the concepts.
        max_size_mask (int): Maximum size of a mask.
        beam_size (int): Number of formulas kept at each step.
        length (int): Maximum length of the formulas to search.

    Returns:
        The formula with the highest IoU in the final beam, or ``None`` when
        no atomic concept has a positive IoU.
    """
    num_hits = bitmaps.sum()

    # One leaf per concept; the mapping gives the concept index of each leaf.
    candidate_labels = [F.Leaf(c) for c in range(len(masks))]
    label_mapping = {leaf: index for index, leaf in enumerate(candidate_labels)}

    _, neuron_quantities, concept_quantities = heuristic_info
    # Keyword arguments keep this call consistent with sort_search_space_by
    # and independent of the helper's positional parameter order.
    atom_ious = Counter(
        {
            leaf: optimal_utils.compute_max_iou_from_label_info(
                label_quantities=concept_quantities[label_mapping[leaf]],
                neuron_quantities=neuron_quantities,
                num_hits=num_hits,
            )
            for leaf in candidate_labels
        }
    )
    first_beam_num = min(len(atom_ious), beam_size * 2)
    beam = {
        leaf: iou
        for leaf, iou in atom_ious.most_common(first_beam_num)
        if iou > 0
    }

    if not beam:
        return None

    beam_masks = {}
    beam_limit = beam_size
    for previous_beam_length in range(1, length):
        is_last_step = previous_beam_length == length - 1

        # Only expand formulas of the previous length, so that parts of the
        # tree explored in earlier steps are not generated again.
        to_expand = [
            formula for formula in beam if len(formula) == previous_beam_length
        ]
        search_space = compute_next_search_space(to_expand, candidate_labels)

        # Rank the next frontier by estimated IoU.
        sorted_search_space = sort_search_space_by(
            search_space=search_space,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            num_hits=num_hits,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )

        # Only the single best formula is needed from the last step.
        if is_last_step:
            beam_limit = 1

        next_formulas, next_ious, _ = beam_search(
            sorted_search_space,
            masks=masks,
            previous_beam=beam,
            beam_masks=beam_masks,
            bitmaps=bitmaps,
            beam_limit=beam_limit,
        )

        # Merge the new formulas into the beam, then trim it to size.
        beam.update(zip(next_formulas, next_ious))
        beam = dict(Counter(beam).most_common(beam_limit))

        # The heuristic data is only needed when another step follows.
        if not is_last_step:
            beam_masks, (heuristic_info, label_mapping) = get_beam_info(
                beam=beam.keys(),
                masks=masks,
                beam_masks=beam_masks,
                heuristic_info=heuristic_info,
                label_mapping=label_mapping,
                bitmaps=bitmaps,
                length=length,
            )

    best_label, _ = Counter(beam).most_common(1)[0]
    return best_label
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def compute_beam_explanations(
    *, bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir=None
):
    """Compute a beam-search explanation for the given bitmaps and masks.

    Args:
        bitmaps (torch.Tensor): Activations to explain; moved to ``device``.
        masks (list): Concept masks; the element count of the first mask is
            used as the maximum mask size.
        disjoint_info: Disjointness information for the concepts.
        length (int): Maximum length of the formulas to search.
        beam_size (int): Number of formulas kept at each search step.
        device (torch.device): Device on which the bitmaps are placed.
        cache_dir (str, optional): Passed to ``mask_utils.get_dataset_info``
            as the location for cached mask information. Defaults to None.

    Returns:
        The value returned by ``perform_search``: the best formula found, or
        ``None`` when the search has no usable atomic concept.
    """
    max_size_mask = torch.numel(masks[0])
    bitmaps = bitmaps.to(device)

    # Per-dataset mask statistics (optionally cached on disk).
    masks_info = mask_utils.get_dataset_info(masks, cache_dir)

    # Quantities required by the heuristic.
    neuron_quantities, concept_quantities = optimal_utils.get_optimal_heuristic_info(
        masks=masks, bitmaps=bitmaps, masks_quantities=masks_info
    )

    return perform_search(
        masks=masks,
        bitmaps=bitmaps,
        heuristic_info=(masks_info, neuron_quantities, concept_quantities),
        disjoint_info=disjoint_info,
        max_size_mask=max_size_mask,
        beam_size=beam_size,
        length=length,
    )
|