compositional-explanations 0.0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- compositional_explanations/__init__.py +0 -0
- compositional_explanations/beam.py +405 -0
- compositional_explanations/formula.py +418 -0
- compositional_explanations/metrics.py +65 -0
- compositional_explanations/optimal.py +947 -0
- compositional_explanations/optimal_sample_heuristic.py +776 -0
- compositional_explanations/optimal_sum_heuristic.py +635 -0
- compositional_explanations/path_heuristic.py +675 -0
- compositional_explanations-0.0.0.1.dist-info/LICENSE +19 -0
- compositional_explanations-0.0.0.1.dist-info/METADATA +86 -0
- compositional_explanations-0.0.0.1.dist-info/RECORD +17 -0
- compositional_explanations-0.0.0.1.dist-info/WHEEL +5 -0
- compositional_explanations-0.0.0.1.dist-info/top_level.txt +2 -0
- utils/constants.py +30 -0
- utils/general_utils.py +28 -0
- utils/mask_utils.py +125 -0
- utils/optimal_utils.py +590 -0
|
File without changes
|
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
"""Beam-search variant guided by the optimal heuristic family."""
|
|
2
|
+
|
|
3
|
+
from collections import Counter
|
|
4
|
+
import queue as Q
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
from compositional_explanations import formula as F
|
|
9
|
+
from compositional_explanations import optimal_sample_heuristic, metrics
|
|
10
|
+
from utils import optimal_utils, mask_utils
|
|
11
|
+
|
|
12
|
+
def beam_search(
    search_space,
    *,
    masks,
    beam_masks,
    bitmaps,
    beam_limit,
    previous_beam=None,
):
    """Perform the beam search on the search space.

    The beam is first seeded with the entries of ``previous_beam``; then the
    candidates of ``search_space`` are visited in the given order, their exact
    IoU is computed and they are admitted to the beam when there is room or
    when they beat the current worst entry.

    Args:
        search_space (list): Candidate formulas. Each one must expose an
            ``iou`` attribute holding its estimated IoU. The early stop below
            is only meaningful when the list is sorted by that estimate in
            descending order (as ``sort_search_space_by`` returns it).
        masks (dict): Concept masks, forwarded to
            ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cached formula masks for formulas in the current
            beam, forwarded to ``mask_utils.get_formula_mask``.
        bitmaps (torch.Tensor): Activation bitmaps; its ``device`` is used for
            the mask computation and it is passed to ``metrics.iou``.
        beam_limit (int): The beam size. It is used as the ``maxsize`` of a
            ``queue.PriorityQueue``, so a value <= 0 means "unbounded".
        previous_beam (dict): A dictionary mapping beam formulas to their IoU.
            Defaults to an empty beam.

    Returns:
        current_beam_formulas (list): The formulas kept in the beam, ordered
            by ascending IoU.
        current_beam_iou (list): The IoU values corresponding to
            `current_beam_formulas`.
        visited_indices (int): The number of formulas whose exact IoU was
            evaluated.
    """
    if previous_beam is None:
        previous_beam = {}

    current_beam = Q.PriorityQueue(beam_limit)
    # Entries are (iou, insertion_index, formula). The insertion index breaks
    # ties between equal IoU values so that the formulas themselves are never
    # compared (they are not guaranteed to define an ordering).
    insertion_index = 0
    minimum = 0
    visited_indices = 0

    def offer(score, item):
        """Admit ``item`` if the beam has room or ``score`` beats the worst entry."""
        nonlocal insertion_index, minimum
        if current_beam.full():
            if not score > minimum:
                return False
            # Evict the current worst entry to make room.
            current_beam.get()
        current_beam.put((score, insertion_index, item))
        insertion_index += 1
        minimum = current_beam.queue[0][0]
        return True

    # Init beam with previous best
    for label, iou in previous_beam.items():
        offer(iou, label)

    # Iterate over the search space
    for candidate_formula in search_space:
        # If the estimated IoU is less than the minimum, we can stop the
        # search if the beam is full: later candidates have lower estimates.
        # NOTE(review): this presumes the estimate never underestimates the
        # exact IoU -- confirm against the heuristic.
        if current_beam.full() and candidate_formula.iou < minimum:
            break

        # Compute IoU
        masks_formula = mask_utils.get_formula_mask(
            candidate_formula, masks, beam_masks, device=bitmaps.device
        )
        iou = metrics.iou(masks_formula, bitmaps)

        # Update visited nodes
        visited_indices += 1

        # Update beam; only admitted candidates get their estimate replaced
        # by the exact IoU.
        if offer(iou, candidate_formula):
            candidate_formula.iou = iou

    # Extract formulas and iou from the beam (ascending IoU)
    current_beam_formulas = []
    current_beam_iou = []
    while not current_beam.empty():
        iou, _, candidate = current_beam.get()
        current_beam_formulas.append(candidate)
        current_beam_iou.append(iou)
    return current_beam_formulas, current_beam_iou, visited_indices
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def compute_next_search_space(formulas, candidate_labels):
    """Compute the next search space starting from the current beam
    of formulas.

    Every formula is extended with every candidate term it does not already
    contain, using three connectives: ``formula OR term``,
    ``formula AND term`` and ``formula AND NOT term``.

    Args:
        formulas (list): A list of formulas. Each must provide ``get_vals()``.
        candidate_labels (list): A list of candidate labels. Each must expose
            a ``val`` attribute.

    Returns:
        search_space (list): The de-duplicated list of extended formulas, in
            order of first generation. Each formula has its ``iou`` attribute
            initialised to ``1.0``.
    """
    # A dict is used as an insertion-ordered set: duplicates (per the
    # formulas' own __hash__/__eq__) are dropped, but unlike ``set`` the
    # resulting order does not depend on hash values.
    unique_candidates = {}

    for formula in formulas:
        vals_formula = set(formula.get_vals())
        for candidate_term in candidate_labels:
            # remove dummy cases with void masks or equivalent formulas
            if candidate_term.val in vals_formula:
                continue
            for op, negate in ((F.Or, False), (F.And, False), (F.And, True)):
                operand = F.Not(candidate_term) if negate else candidate_term
                candidate_formula = op(formula, operand)
                # Placeholder value; sort_search_space_by overwrites it with
                # the actual estimate.
                candidate_formula.iou = 1.0
                unique_candidates.setdefault(candidate_formula, None)

    return list(unique_candidates)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def compute_beam_quantities(beam, masks, beam_masks, heuristic_info, bitmaps):
    """Build the heuristic information of the beam formulas not yet cached.

    Args:
        beam (iterable): The formulas currently in the beam.
        masks: The concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cache mapping formulas to their masks. It is updated
            in place with the masks computed here.
        heuristic_info (tuple): A 3-tuple whose first item holds the
            segmentation quantities and whose second item holds the neuron
            quantities; the third item is not used here.
        bitmaps (torch.Tensor): The activation bitmaps; its ``device`` is where
            the computation takes place.
    Returns:
        tuple: ``(beam_info, beam_masks)`` where ``beam_info`` maps every newly
        processed formula to its concept info and ``beam_masks`` is the
        (mutated) cache received as input.
    """
    segmentation_part, neuron_part, _ = heuristic_info
    (unique_of_neuron, _), (common_of_neuron, _), _, _, _, _ = neuron_part
    shared_elements, exclusive_elements, _ = segmentation_part

    # These tensors are needed for every formula: move them once to the
    # device of the bitmaps instead of doing it inside the loop.
    exclusive_elements = exclusive_elements.to(bitmaps.device)
    shared_elements = shared_elements.to(bitmaps.device)

    collected_info = {}
    for formula in beam:
        # Cached formulas and leaves already have their information.
        if label_needs_update := (
            formula not in beam_masks and not isinstance(formula, F.Leaf)
        ):
            formula_mask = mask_utils.get_formula_mask(
                formula, masks, beam_masks, device=bitmaps.device
            )
            quantities = optimal_utils.compute_quantities_vector(
                label_mask=formula_mask,
                bitmaps=bitmaps,
                common_elements=shared_elements,
                unique_elements=exclusive_elements,
                neuron_common=common_of_neuron,
                neuron_unique=unique_of_neuron,
            )
            collected_info[formula] = optimal_utils.get_concept_info(quantities)
            # The cache keeps the masks on the CPU.
            beam_masks[formula] = formula_mask.cpu()
    return collected_info, beam_masks
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def get_beam_info(
    *, beam, masks, beam_masks, heuristic_info, label_mapping, bitmaps, length
):
    """Refresh the cached masks and the heuristic data with the beam formulas.

    Args:
        beam (iterable): The formulas currently in the beam.
        masks: The concept masks.
        beam_masks (dict): Cache mapping beam formulas to their masks.
        heuristic_info (tuple): The heuristic information, forwarded to
            ``compute_beam_quantities`` and ``optimal_utils.update_heuristic_info``.
        label_mapping (dict): Mapping from labels to the index of their entry
            in the heuristic information.
        bitmaps (torch.Tensor): The activation bitmaps.
        length (int): The maximum length of the formulas to search.
    Returns:
        tuple: ``(beam_masks, (heuristic_info, label_mapping))`` with the
        updated mask cache, heuristic information and label mapping.
    """
    # Quantities of the formulas that entered the beam at this step.
    nodes_info, refreshed_masks = compute_beam_quantities(
        beam, masks, beam_masks, heuristic_info, bitmaps
    )

    # Make the new formulas available to the estimator.
    refreshed_heuristic, refreshed_mapping = optimal_utils.update_heuristic_info(
        nodes_info=nodes_info,
        heuristic_info=heuristic_info,
        label_mapping=label_mapping,
        max_length=length,
    )
    return refreshed_masks, (refreshed_heuristic, refreshed_mapping)
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
def sort_search_space_by(
    *,
    search_space,
    label_mapping,
    heuristic_info,
    disjoint_info,
    num_hits,
    max_size_mask
):
    """Rank the candidate formulas by their estimated IoU, best first.

    The estimate is stored in the ``iou`` attribute of every formula of
    ``search_space`` before sorting.

    Args:
        search_space (list): The formulas to rank.
        label_mapping (dict): Mapping from labels to the index of their entry
            in the heuristic information.
        heuristic_info (tuple): A 3-tuple whose second item holds the neuron
            quantities; the whole tuple is forwarded to the estimator.
        disjoint_info (dict): The disjoint information for the concepts.
        num_hits: The number of hits in the bitmaps.
        max_size_mask (int): The maximum size of the mask.
    Returns:
        list: A new list with the formulas in descending order of estimated iou.
    """
    _, neuron_quantities, _ = heuristic_info

    for formula in search_space:
        estimated_quantities = optimal_utils.estimate_label_quantities(
            heuristic=optimal_sample_heuristic,
            label=formula,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )
        if estimated_quantities is not None:
            formula.iou = optimal_utils.compute_max_iou_from_label_info(
                label_quantities=estimated_quantities,
                num_hits=num_hits,
                neuron_quantities=neuron_quantities,
            )
        else:
            # No quantities: the label was discarded at the previous step.
            formula.iou = 0.0

    # Highest estimate first.
    return sorted(search_space, key=lambda formula: formula.iou, reverse=True)
|
|
253
|
+
|
|
254
|
+
|
|
255
|
+
def perform_search(
    *,
    masks,
    bitmaps,
    heuristic_info,
    disjoint_info,
    max_size_mask,
    beam_size=5,
    length=3
):
    """Performs the beam optimal search to find the best formula that explains the bitmaps given the masks and the heuristic information.

    The search starts from the atomic concepts with the highest estimated IoU
    and, for ``length - 1`` rounds, extends the formulas of the beam by one
    term, ranks the extensions by estimated IoU and keeps the best ones
    according to ``beam_search``.

    Args:
        masks (list): The concept masks; ``len(masks)`` gives the number of atomic concepts.
        bitmaps (torch.Tensor): The activation bitmaps.
        heuristic_info (tuple): A tuple containing the masks_info, neuron_quantities, and concept quantities for the optimal heuristic.
        disjoint_info (dict): A dictionary containing the disjoint information for the concepts.
        max_size_mask (int): The maximum size of the mask.
        beam_size (int): The beam size for the search. The first beam holds up to ``2 * beam_size`` atoms and the last round keeps a single formula.
        length (int): The maximum length of the formulas to search.
    Returns:
        The formula with the highest IoU in the final beam, or ``None`` when no atomic concept has a positive estimated IoU.
    """

    # Number of hits
    num_hits = bitmaps.sum()

    # Utilities
    label_mapping = {F.Leaf(c): c for c in range(len(masks))}

    # Extract first beam and candidate concepts
    candidate_labels = [F.Leaf(c) for c in range(len(masks))]

    _, neuron_quantities, concept_quantities = heuristic_info
    # Keyword arguments keep this call consistent with the one in
    # sort_search_space_by and independent of the helper's parameter order.
    iou_atoms = Counter(
        {
            leaf: optimal_utils.compute_max_iou_from_label_info(
                label_quantities=concept_quantities[label_mapping[leaf]],
                neuron_quantities=neuron_quantities,
                num_hits=num_hits,
            )
            for leaf in candidate_labels
        }
    )
    first_beam_num = min(len(iou_atoms), beam_size * 2)
    beam_atoms = {
        lab: iou for lab, iou in iou_atoms.most_common(first_beam_num) if iou > 0
    }

    if not beam_atoms:
        # Nothing overlaps the bitmaps: there is no formula to return.
        return None

    # Beam Search
    beam_masks = {}
    beam = beam_atoms.copy()
    last_step = length - 1
    for previous_beam_length in range(1, length):
        # Only expand formulas of the previous beam length to avoid regenerating beam tree already explored
        to_expand = [lab for lab in beam if len(lab) == previous_beam_length]
        search_space = compute_next_search_space(
            to_expand,
            candidate_labels,
        )

        # Compute the estimation for the next frontier
        sorted_search_space = sort_search_space_by(
            search_space=search_space,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            num_hits=num_hits,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )

        # If we are in the last iteration, we set the beam size to 1 to get only the best formula
        if previous_beam_length == last_step:
            beam_size = 1

        # Perform beam search (the number of visited formulas is not used)
        next_beam_formulas, next_beam_iou, _ = beam_search(
            sorted_search_space,
            masks=masks,
            previous_beam=beam,
            beam_masks=beam_masks,
            bitmaps=bitmaps,
            beam_limit=beam_size,
        )

        # Update top formulas
        beam.update(zip(next_beam_formulas, next_beam_iou))

        # Trim the beam
        beam = dict(Counter(beam).most_common(beam_size))

        # Update infos if there are step left
        if previous_beam_length < last_step:
            beam_masks, (heuristic_info, label_mapping) = get_beam_info(
                beam=beam.keys(),
                masks=masks,
                beam_masks=beam_masks,
                heuristic_info=heuristic_info,
                label_mapping=label_mapping,
                bitmaps=bitmaps,
                length=length,
            )

    best_label, _ = Counter(beam).most_common(1)[0]
    return best_label
|
|
363
|
+
|
|
364
|
+
|
|
365
|
+
def compute_beam_explanations(
    *, bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir=None
):
    """Entry point: gather the heuristic data and run the beam search.

    Args:
        bitmaps (torch.Tensor): The activation bitmaps; moved to ``device``.
        masks (list): The concept masks. The number of elements of ``masks[0]`` is used as the maximum mask size.
        disjoint_info (dict): The disjoint information for the concepts.
        length (int): The maximum length of the formulas to search.
        beam_size (int): The beam size for the search.
        device (torch.device): The device to perform the computations on.
        cache_dir (str, optional): Forwarded to ``mask_utils.get_dataset_info`` together with the masks. Defaults to None.
    Returns:
        The value returned by ``perform_search``: the best formula found, or ``None`` if the search has no starting atom.
    """
    mask_size = torch.numel(masks[0])
    device_bitmaps = bitmaps.to(device)

    # Dataset-level information about the masks
    dataset_info = mask_utils.get_dataset_info(masks, cache_dir)

    # Quantities needed by the optimal heuristic
    neuron_part, concept_part = optimal_utils.get_optimal_heuristic_info(
        masks=masks, bitmaps=device_bitmaps, masks_quantities=dataset_info
    )

    return perform_search(
        masks=masks,
        bitmaps=device_bitmaps,
        heuristic_info=(dataset_info, neuron_part, concept_part),
        disjoint_info=disjoint_info,
        max_size_mask=mask_size,
        beam_size=beam_size,
        length=length,
    )
|