compositional-explanations 0.0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
@@ -0,0 +1,405 @@
1
+ """Beam-search variant guided by the optimal heuristic family."""
2
+
3
+ from collections import Counter
4
+ import queue as Q
5
+
6
+ import torch
7
+
8
+ from compositional_explanations import formula as F
9
+ from compositional_explanations import optimal_sample_heuristic, metrics
10
+ from utils import optimal_utils, mask_utils
11
+
12
def _offer_to_beam(beam, iou, order, formula):
    """Insert ``formula`` into the bounded min-queue ``beam`` if it qualifies.

    The entry is always accepted while the queue has room. Once the queue is
    full, it replaces the current worst entry only if ``iou`` is strictly
    greater than that entry's IoU.

    Returns:
        bool: True if the formula was inserted, False otherwise.
    """
    if beam.full():
        if not iou > beam.queue[0][0]:
            return False
        # Evict the current worst entry to make room.
        beam.get()
    beam.put((iou, order, formula))
    return True


def beam_search(
    search_space,
    *,
    masks,
    beam_masks,
    bitmaps,
    beam_limit,
    previous_beam=None,
):
    """Select the best formulas of ``search_space`` according to their exact IoU.

    The beam is first seeded with ``previous_beam``. The candidates are then
    visited in the given order; the loop stops as soon as the beam is full and
    a candidate's estimated IoU (``candidate.iou``) is below the worst IoU in
    the beam, so ``search_space`` is expected to be sorted by estimated IoU in
    descending order.

    Args:
        search_space (list): Candidate formulas. Each one exposes an ``iou``
            attribute holding its estimated IoU; the attribute is overwritten
            with the exact IoU for candidates that enter the beam.
        masks: Concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cached formula masks, forwarded to
            ``mask_utils.get_formula_mask``.
        bitmaps (torch.Tensor): Target activations. The exact IoU is computed
            against them and their ``device`` is used for the formula masks.
        beam_limit (int): The beam size. NOTE: ``queue.PriorityQueue`` treats
            a size <= 0 as unbounded.
        previous_beam (dict, optional): Mapping from formulas to their IoU,
            used to seed the beam. Defaults to an empty beam.

    Returns:
        tuple: ``(formulas, ious, visited)`` where ``formulas`` and ``ious``
        are parallel lists sorted by ascending IoU and ``visited`` is the
        number of candidates whose exact IoU was evaluated.
    """
    if previous_beam is None:
        previous_beam = {}

    # Entries are (iou, insertion_order, formula). The insertion counter is
    # unique, so two entries with the same IoU are never ordered by comparing
    # the formula objects themselves (which may not support ``<``).
    current_beam = Q.PriorityQueue(beam_limit)
    insertion_order = 0
    minimum = 0
    visited_indices = 0

    # Init beam with previous best
    for label, iou in previous_beam.items():
        if _offer_to_beam(current_beam, iou, insertion_order, label):
            insertion_order += 1
            minimum = current_beam.queue[0][0]

    # Iterate over the search space
    for candidate_formula in search_space:
        # Nothing that follows can enter a full beam once the estimate drops
        # below the worst IoU currently kept.
        if current_beam.full() and candidate_formula.iou < minimum:
            break

        # Compute the exact IoU
        masks_formula = mask_utils.get_formula_mask(
            candidate_formula, masks, beam_masks, device=bitmaps.device
        )
        iou = metrics.iou(masks_formula, bitmaps)
        visited_indices += 1

        # Update beam; the estimate is replaced only for accepted candidates
        if _offer_to_beam(current_beam, iou, insertion_order, candidate_formula):
            candidate_formula.iou = iou
            insertion_order += 1
            minimum = current_beam.queue[0][0]

    # Extract formulas and IoU from the beam (ascending IoU)
    current_beam_formulas = []
    current_beam_iou = []
    while not current_beam.empty():
        iou, _, candidate = current_beam.get()
        current_beam_formulas.append(candidate)
        current_beam_iou.append(iou)
    return current_beam_formulas, current_beam_iou, visited_indices
102
+
103
+
104
def compute_next_search_space(formulas, candidate_labels):
    """Compute the next search space starting from the current beam of formulas.

    Every formula is extended with every candidate label it does not already
    contain, in three ways: ``formula OR label``, ``formula AND label`` and
    ``formula AND NOT label``. Duplicates (according to the formulas' hash and
    equality) are removed while keeping the order of first generation, so the
    result is deterministic for a given input order.

    Args:
        formulas (list): The formulas to expand. Each one provides
            ``get_vals()``.
        candidate_labels (list): The candidate labels to attach. Each one
            provides a ``val`` attribute.

    Returns:
        list: The de-duplicated expanded formulas. Each one has its ``iou``
        attribute initialised to ``1.0``.
    """
    search_space = []

    for formula in formulas:
        vals_formula = set(formula.get_vals())
        for candidate_term in candidate_labels:
            # remove dummy cases with void masks or equivalent formulas
            if candidate_term.val in vals_formula:
                continue
            for op, negate in [(F.Or, False), (F.And, False), (F.And, True)]:
                candidate_to_attach = (
                    F.Not(candidate_term) if negate else candidate_term
                )
                candidate_formula = op(formula, candidate_to_attach)
                candidate_formula.iou = 1.0
                search_space.append(candidate_formula)

    # dict.fromkeys de-duplicates like a set but preserves insertion order,
    # so the output order does not depend on hash values.
    return list(dict.fromkeys(search_space))
133
+
134
+
135
def compute_beam_quantities(beam, masks, beam_masks, heuristic_info, bitmaps):
    """Compute mask and heuristic info for the beam formulas not yet cached.

    Formulas already present in ``beam_masks`` and ``F.Leaf`` formulas are
    skipped. For every other formula the mask is built, its quantities and
    info are computed, and the mask is stored (on CPU) in ``beam_masks``.

    Args:
        beam: Iterable of the formulas in the beam.
        masks: Concept masks, forwarded to ``mask_utils.get_formula_mask``.
        beam_masks (dict): Cache mapping formulas to their masks. It is
            updated in place.
        heuristic_info (tuple): ``(seg_quantities, neuron_quantities, _)`` as
            used by the optimal heuristic.
        bitmaps (torch.Tensor): Target activations; their device is used for
            the computation.

    Returns:
        tuple: ``(beam_info, beam_masks)`` where ``beam_info`` maps each newly
        processed formula to its info and ``beam_masks`` is the updated cache.
    """
    seg_quantities, neuron_quantities, _ = heuristic_info
    (neuron_unique, _), (neuron_common, _), _, _, _, _ = neuron_quantities
    common_elements, unique_elements, _ = seg_quantities

    # These tensors are shared by every formula: move them to the working
    # device once, before the loop.
    target_device = bitmaps.device
    unique_on_device = unique_elements.to(target_device)
    common_on_device = common_elements.to(target_device)

    computed_info = {}
    for formula in beam:
        if formula not in beam_masks and not isinstance(formula, F.Leaf):
            formula_mask = mask_utils.get_formula_mask(
                formula, masks, beam_masks, device=bitmaps.device
            )
            quantities = optimal_utils.compute_quantities_vector(
                label_mask=formula_mask,
                bitmaps=bitmaps,
                common_elements=common_on_device,
                unique_elements=unique_on_device,
                neuron_common=neuron_common,
                neuron_unique=neuron_unique,
            )
            computed_info[formula] = optimal_utils.get_concept_info(quantities)
            # Cached masks are kept on CPU.
            beam_masks[formula] = formula_mask.cpu()
    return computed_info, beam_masks
179
+
180
+
181
def get_beam_info(
    *, beam, masks, beam_masks, heuristic_info, label_mapping, bitmaps, length
):
    """Refresh the cached masks and the heuristic state for the current beam.

    The per-formula info is computed by ``compute_beam_quantities`` and then
    merged into the heuristic information and the label mapping through
    ``optimal_utils.update_heuristic_info``.

    Args:
        beam: Iterable of the formulas in the beam.
        masks: Concept masks.
        beam_masks (dict): Cache mapping formulas to their masks.
        heuristic_info (tuple): The heuristic information of the optimal
            heuristic.
        label_mapping (dict): Mapping from labels to indices.
        bitmaps (torch.Tensor): Target activations.
        length (int): The maximum formula length, forwarded as ``max_length``.

    Returns:
        tuple: ``(beam_masks, (heuristic_info, label_mapping))`` with the
        updated mask cache, heuristic information and label mapping.
    """
    nodes_info, updated_masks = compute_beam_quantities(
        beam, masks, beam_masks, heuristic_info, bitmaps
    )
    updated_heuristic, updated_mapping = optimal_utils.update_heuristic_info(
        nodes_info=nodes_info,
        heuristic_info=heuristic_info,
        label_mapping=label_mapping,
        max_length=length,
    )
    heuristic_state = (updated_heuristic, updated_mapping)
    return updated_masks, heuristic_state
206
+
207
+
208
def sort_search_space_by(
    *,
    search_space,
    label_mapping,
    heuristic_info,
    disjoint_info,
    num_hits,
    max_size_mask
):
    """Rank the search space by estimated IoU, best first.

    The estimated IoU of every formula is stored in its ``iou`` attribute.
    A formula whose quantities cannot be estimated (the estimator returns
    ``None``) gets an estimate of ``0.0``.

    Args:
        search_space (list): The formulas to rank.
        label_mapping (dict): Mapping from labels to indices.
        heuristic_info (tuple): ``(_, neuron_quantities, _)`` as used by the
            optimal heuristic.
        disjoint_info (dict): The disjoint information for the concepts.
        num_hits: The number of hits in the bitmaps.
        max_size_mask (int): The maximum size of the mask.

    Returns:
        list: A new list with the formulas in descending order of estimated
        IoU.
    """
    _, neuron_quantities, _ = heuristic_info

    for formula in search_space:
        quantities = optimal_utils.estimate_label_quantities(
            heuristic=optimal_sample_heuristic,
            label=formula,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )
        if quantities is not None:
            formula.iou = optimal_utils.compute_max_iou_from_label_info(
                label_quantities=quantities,
                num_hits=num_hits,
                neuron_quantities=neuron_quantities,
            )
        else:
            # Label discarded at the previous step
            formula.iou = 0.0

    # Stable sort on the estimate, highest first
    ranked = list(search_space)
    ranked.sort(key=lambda formula: formula.iou, reverse=True)
    return ranked
253
+
254
+
255
def perform_search(
    *,
    masks,
    bitmaps,
    heuristic_info,
    disjoint_info,
    max_size_mask,
    beam_size=5,
    length=3
):
    """Run the heuristic-guided beam search and return the best formula found.

    The first beam is made of the atomic concepts with the highest positive
    IoU (at most ``2 * beam_size`` of them). At every step the formulas of the
    previous length are expanded, the expansions are ranked by estimated IoU
    and ``beam_search`` keeps the best ones. On the last step the beam size is
    reduced to 1.

    Args:
        masks: Concept masks; ``len(masks)`` gives the number of atomic
            concepts.
        bitmaps (torch.Tensor): Target activations.
        heuristic_info (tuple): ``(masks_info, neuron_quantities,
            concept_quantities)`` for the optimal heuristic.
        disjoint_info (dict): The disjoint information for the concepts.
        max_size_mask (int): The maximum size of the mask.
        beam_size (int): The beam size for the search.
        length (int): The maximum length of the formulas to search.

    Returns:
        The formula with the highest IoU in the final beam, or ``None`` when
        no atomic concept has a positive IoU.
    """
    # Number of hits
    num_hits = bitmaps.sum()

    # Atomic concepts and their indices
    label_mapping = {F.Leaf(c): c for c in range(len(masks))}
    candidate_labels = [F.Leaf(c) for c in range(len(masks))]

    # Score the atoms. Keyword arguments keep this call consistent with the
    # one in sort_search_space_by and independent of the positional order.
    _, neuron_quantities, concept_quantities = heuristic_info
    iou_atoms = Counter(
        {
            label: optimal_utils.compute_max_iou_from_label_info(
                label_quantities=concept_quantities[label_mapping[label]],
                neuron_quantities=neuron_quantities,
                num_hits=num_hits,
            )
            for label in candidate_labels
        }
    )

    # First beam: best atoms with a strictly positive IoU
    first_beam_num = min(len(iou_atoms), beam_size * 2)
    beam = {
        lab: iou for lab, iou in iou_atoms.most_common(first_beam_num) if iou > 0
    }
    if not beam:
        return None

    # Beam Search
    beam_masks = {}
    last_step = length - 1
    for previous_beam_length in range(1, length):
        # Only expand formulas of the previous beam length to avoid
        # regenerating parts of the beam tree already explored
        to_expand = [lab for lab in beam if len(lab) == previous_beam_length]
        search_space = compute_next_search_space(to_expand, candidate_labels)

        # Compute the estimation for the next frontier
        sorted_search_space = sort_search_space_by(
            search_space=search_space,
            label_mapping=label_mapping,
            heuristic_info=heuristic_info,
            num_hits=num_hits,
            max_size_mask=max_size_mask,
            disjoint_info=disjoint_info,
        )

        # On the last iteration only the single best formula is needed
        if previous_beam_length == last_step:
            beam_size = 1

        # Perform beam search (the visited-node count is not used here)
        next_beam_formulas, next_beam_iou, _ = beam_search(
            sorted_search_space,
            masks=masks,
            previous_beam=beam,
            beam_masks=beam_masks,
            bitmaps=bitmaps,
            beam_limit=beam_size,
        )

        # Update top formulas, then trim the beam
        beam.update(zip(next_beam_formulas, next_beam_iou))
        beam = dict(Counter(beam).most_common(beam_size))

        # Update infos if there are steps left
        if previous_beam_length < last_step:
            beam_masks, (heuristic_info, label_mapping) = get_beam_info(
                beam=beam.keys(),
                masks=masks,
                beam_masks=beam_masks,
                heuristic_info=heuristic_info,
                label_mapping=label_mapping,
                bitmaps=bitmaps,
                length=length,
            )

    best_label, _ = Counter(beam).most_common(1)[0]
    return best_label
363
+
364
+
365
def compute_beam_explanations(
    *, bitmaps, masks, disjoint_info, length, beam_size, device, cache_dir=None
):
    """Compute the beam explanation for the given bitmaps and masks.

    Builds the heuristic information from the masks and the bitmaps and then
    delegates to ``perform_search``.

    Args:
        bitmaps (torch.Tensor): Target activations; moved to ``device``.
        masks: Concept masks. ``masks[0]`` determines the mask size.
        disjoint_info (dict): The disjoint information for the concepts.
        length (int): The maximum length of the formulas to search.
        beam_size (int): The beam size for the search.
        device (torch.device): The device to perform the computations on.
        cache_dir (str, optional): Forwarded to ``mask_utils.get_dataset_info``
            as the cache location for the masks' information. Defaults to
            None.

    Returns:
        The best formula found by ``perform_search``, or ``None`` when there
        are no masks or no suitable atomic concept.
    """
    # Without masks there is nothing to search; mirror perform_search, which
    # returns None when no candidate is available, instead of failing on
    # masks[0].
    if len(masks) == 0:
        return None

    max_size_mask = torch.numel(masks[0])
    bitmaps = bitmaps.to(device)

    # Get masks info
    masks_info = mask_utils.get_dataset_info(masks, cache_dir)

    # Get quantities
    neuron_quantities, concept_quantities = optimal_utils.get_optimal_heuristic_info(
        masks=masks, bitmaps=bitmaps, masks_quantities=masks_info
    )

    heuristic_info = (masks_info, neuron_quantities, concept_quantities)
    return perform_search(
        masks=masks,
        bitmaps=bitmaps,
        heuristic_info=heuristic_info,
        disjoint_info=disjoint_info,
        max_size_mask=max_size_mask,
        beam_size=beam_size,
        length=length,
    )