explainiverse 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,865 @@
1
+ # src/explainiverse/explainers/gradient/tcav.py
2
+ """
3
+ TCAV - Testing with Concept Activation Vectors.
4
+
5
+ TCAV provides human-interpretable explanations by quantifying how much
6
+ a model's predictions are influenced by high-level concepts. Instead of
7
+ attributing importance to individual features, TCAV explains which
8
+ human-understandable concepts (e.g., "striped", "furry") are important
9
+ for a model's predictions.
10
+
11
+ Key Components:
12
+ - Concept Activation Vectors (CAVs): Learned direction in activation space
13
+ that separates concept examples from random examples
14
+ - Directional Derivatives: Gradient of model output in CAV direction
15
+ - TCAV Score: Fraction of inputs where concept positively influences prediction
16
+ - Statistical Testing: Significance against random concepts
17
+
18
+ Reference:
19
+ Kim, B., Wattenberg, M., Gilmer, J., Cai, C., Wexler, J., Viegas, F., &
20
+ Sayres, R. (2018). Interpretability Beyond Feature Attribution:
21
+ Quantitative Testing with Concept Activation Vectors (TCAV).
22
+ ICML 2018. https://arxiv.org/abs/1711.11279
23
+
24
+ Example:
25
+ from explainiverse.explainers.gradient import TCAVExplainer
26
+ from explainiverse.adapters import PyTorchAdapter
27
+
28
+ adapter = PyTorchAdapter(model, task="classification")
29
+
30
+ explainer = TCAVExplainer(
31
+ model=adapter,
32
+ layer_name="layer3",
33
+ class_names=["zebra", "horse", "dog"]
34
+ )
35
+
36
+ # Learn a concept (e.g., "striped")
37
+ explainer.learn_concept(
38
+ concept_name="striped",
39
+ concept_examples=striped_images,
40
+ negative_examples=random_images
41
+ )
42
+
43
+ # Compute TCAV score for target class
44
+ result = explainer.compute_tcav_score(
45
+ test_inputs=test_images,
46
+ target_class=0, # zebra
47
+ concept_name="striped"
48
+ )
49
+ """
50
+
51
+ import numpy as np
52
+ from typing import List, Optional, Dict, Any, Union, Tuple
53
+ from collections import defaultdict
54
+
55
+ from explainiverse.core.explainer import BaseExplainer
56
+ from explainiverse.core.explanation import Explanation
57
+
58
+ # Check if sklearn is available for linear classifier
59
+ try:
60
+ from sklearn.linear_model import SGDClassifier, LogisticRegression
61
+ from sklearn.model_selection import train_test_split
62
+ from sklearn.metrics import accuracy_score
63
+ SKLEARN_AVAILABLE = True
64
+ except ImportError:
65
+ SKLEARN_AVAILABLE = False
66
+
67
+ # Check if scipy is available for statistical tests
68
+ try:
69
+ from scipy import stats
70
+ SCIPY_AVAILABLE = True
71
+ except ImportError:
72
+ SCIPY_AVAILABLE = False
73
+
74
+
75
def _check_dependencies():
    """Verify that the optional packages TCAV relies on were importable.

    The availability flags are set by the guarded imports at module load
    time; scikit-learn is checked before scipy.

    Raises:
        ImportError: If scikit-learn or scipy could not be imported.
    """
    requirements = (
        (
            SKLEARN_AVAILABLE,
            "scikit-learn is required for TCAV. "
            "Install it with: pip install scikit-learn",
        ),
        (
            SCIPY_AVAILABLE,
            "scipy is required for TCAV statistical testing. "
            "Install it with: pip install scipy",
        ),
    )
    for is_available, message in requirements:
        if not is_available:
            raise ImportError(message)
87
+
88
+
89
class ConceptActivationVector:
    """
    Represents a learned Concept Activation Vector (CAV).

    A CAV is the normal vector to the hyperplane that separates
    concept examples from random (negative) examples in the
    activation space of a neural network layer.

    Attributes:
        concept_name: Human-readable name of the concept
        layer_name: Name of the layer this CAV was trained on
        vector: The CAV direction, rescaled to unit L2 norm
        classifier: The trained linear classifier
        accuracy: Classification accuracy (Python float)
        metadata: Additional training information
    """

    def __init__(
        self,
        concept_name: str,
        layer_name: str,
        vector: np.ndarray,
        classifier: Any,
        accuracy: float,
        metadata: Optional[Dict[str, Any]] = None
    ):
        """
        Store the CAV and normalize its direction to unit length.

        Args:
            concept_name: Human-readable name of the concept.
            layer_name: Name of the layer the CAV belongs to.
            vector: Raw (un-normalized) direction vector.
            classifier: The fitted linear classifier the vector came from.
            accuracy: Classifier accuracy; converted to a Python float.
            metadata: Optional extra information; defaults to an empty dict.

        Raises:
            ValueError: If the vector has zero or non-finite norm and
                therefore defines no direction.
        """
        direction = np.asarray(vector)
        norm = np.linalg.norm(direction)
        # Dividing by a zero/NaN norm would silently fill the CAV with NaNs,
        # and every directional derivative computed from it would be NaN too.
        if not np.isfinite(norm) or norm == 0:
            raise ValueError(
                f"Cannot build CAV for concept '{concept_name}': the "
                f"direction vector has zero or non-finite norm."
            )

        self.concept_name = concept_name
        self.layer_name = layer_name
        self.vector = direction / norm
        self.classifier = classifier
        # sklearn.metrics.accuracy_score returns numpy.float64; keep a
        # plain Python float so the value serializes cleanly.
        self.accuracy = float(accuracy)
        self.metadata = metadata if metadata is not None else {}

    def __repr__(self):
        return (f"CAV(concept='{self.concept_name}', "
                f"layer='{self.layer_name}', "
                f"accuracy={self.accuracy:.3f})")
128
+
129
+
130
+ class TCAVExplainer(BaseExplainer):
131
+ """
132
+ TCAV (Testing with Concept Activation Vectors) explainer.
133
+
134
+ TCAV explains model predictions using high-level human concepts
135
+ rather than low-level features. It quantifies how sensitive a
136
+ model's predictions are to specific concepts.
137
+
138
+ The TCAV score for a concept C and class k is the fraction of
139
+ inputs of class k for which the model's prediction increases
140
+ when moving in the direction of concept C.
141
+
142
+ Attributes:
143
+ model: Model adapter with layer access (PyTorchAdapter)
144
+ layer_name: Target layer for activation extraction
145
+ class_names: List of class names
146
+ concepts: Dictionary of learned CAVs
147
+ random_concepts: Dictionary of random CAVs for statistical testing
148
+ """
149
+
150
def __init__(
    self,
    model,
    layer_name: str,
    class_names: Optional[List[str]] = None,
    cav_classifier: str = "logistic",
    random_seed: int = 42
):
    """
    Initialize the TCAV explainer.

    Args:
        model: A model adapter with get_layer_output() and
            get_layer_gradients() methods. Use PyTorchAdapter.
        layer_name: Name of the layer to extract activations from.
            Use model.list_layers() to see available layers.
        class_names: Class names for the model's outputs. Any iterable
            (including a numpy array) is accepted and copied to a list.
        cav_classifier: Type of linear classifier for CAV training:
            - "logistic": Logistic Regression (default)
            - "sgd": SGD Classifier (faster for large data)
            Any other value trains a logistic regression; a warning is
            emitted so a typo does not go unnoticed.
        random_seed: Random seed for reproducibility.

    Raises:
        ImportError: If scikit-learn or scipy is not installed.
        TypeError: If the adapter lacks a required layer-access method.
        ValueError: If the adapter can list its layers and `layer_name`
            is not among them.
    """
    _check_dependencies()
    super().__init__(model)

    # Validate model capabilities
    for method_name in ("get_layer_output", "get_layer_gradients"):
        if not hasattr(model, method_name):
            raise TypeError(
                f"Model adapter must have {method_name}() method. "
                "Use PyTorchAdapter for PyTorch models."
            )

    if cav_classifier not in ("logistic", "sgd"):
        import warnings
        warnings.warn(
            f"Unknown cav_classifier '{cav_classifier}'; "
            "a logistic regression will be used.",
            UserWarning,
            stacklevel=2
        )

    # Compare against None instead of relying on truthiness, which raises
    # for multi-element numpy arrays. An empty collection still means
    # "no class names".
    names = list(class_names) if class_names is not None else None

    self.layer_name = layer_name
    self.class_names = names or None
    self.cav_classifier = cav_classifier
    self.random_seed = random_seed

    # Storage for learned concepts
    self.concepts: Dict[str, ConceptActivationVector] = {}
    self.random_concepts: Dict[str, List[ConceptActivationVector]] = defaultdict(list)

    # Validate layer exists (only possible when the adapter can list layers)
    if hasattr(model, 'list_layers'):
        available_layers = model.list_layers()
        if layer_name not in available_layers:
            raise ValueError(
                f"Layer '{layer_name}' not found. "
                f"Available layers: {available_layers}"
            )
204
+
205
def _get_activations(self, inputs: np.ndarray) -> np.ndarray:
    """
    Fetch the target layer's activations as a 2-D matrix.

    Outputs with more than two dimensions (e.g. CNN feature maps) are
    reduced by averaging over every axis after the first two; outputs
    with at most two dimensions are returned untouched.

    Args:
        inputs: Input data (n_samples, ...)

    Returns:
        Activations of the configured layer, pooled to two dimensions
        when the layer output has more.
    """
    layer_output = self.model.get_layer_output(inputs, self.layer_name)

    if layer_output.ndim <= 2:
        return layer_output

    # Global average pooling over all trailing (spatial) axes.
    trailing_axes = tuple(range(2, layer_output.ndim))
    return layer_output.mean(axis=trailing_axes)
224
+
225
def _get_gradients_wrt_layer(
    self,
    inputs: np.ndarray,
    target_class: int
) -> np.ndarray:
    """
    Fetch gradients of the target-class output w.r.t. the layer.

    The adapter returns an (activations, gradients) pair; only the
    gradients are used. Gradients with more than two dimensions are
    averaged over every axis after the first two, mirroring the pooling
    applied to activations.

    Args:
        inputs: Input data
        target_class: Target class index

    Returns:
        Gradients w.r.t. layer activations, pooled to two dimensions
        when the adapter returns more.
    """
    _, layer_grads = self.model.get_layer_gradients(
        inputs, self.layer_name, target_class=target_class
    )

    if layer_grads.ndim > 2:
        spatial_axes = tuple(range(2, layer_grads.ndim))
        layer_grads = layer_grads.mean(axis=spatial_axes)

    return layer_grads
249
+
250
def _train_cav(
    self,
    concept_activations: np.ndarray,
    negative_activations: np.ndarray,
    test_size: float = 0.2
) -> Tuple[np.ndarray, Any, float]:
    """
    Train a CAV (linear classifier) to separate concept from negative examples.

    Concept rows are labelled 1 and negative rows 0. A stratified
    hold-out split is used for the accuracy estimate when `test_size`
    is positive, there are at least 10 samples in total, and each class
    has at least 2 samples (stratification needs that many). Otherwise
    the classifier is trained and scored on the full data, so the
    reported accuracy is a training accuracy.

    Args:
        concept_activations: Activations for concept examples
        negative_activations: Activations for negative examples
        test_size: Fraction of data to use for accuracy estimation

    Returns:
        Tuple of (cav_vector, classifier, accuracy)
        Note: accuracy is returned as Python float, not numpy.float64

    Raises:
        ValueError: If either group of activations is empty.
    """
    # NOTE: this reseeds NumPy's *global* RNG. The split and classifiers
    # below receive random_state explicitly; the call is kept so existing
    # callers observe the same global RNG state as before.
    np.random.seed(self.random_seed)

    n_concept = len(concept_activations)
    n_negative = len(negative_activations)
    if n_concept == 0 or n_negative == 0:
        # A linear separator needs both classes; fail with a clear message
        # rather than an opaque stacking/solver error.
        raise ValueError(
            "Both concept and negative activations must contain at least "
            f"one sample (got {n_concept} concept, {n_negative} negative)."
        )

    # Prepare training data
    X = np.vstack([concept_activations, negative_activations])
    y = np.array([1] * n_concept + [0] * n_negative)

    # Split for accuracy estimation. Stratified splitting fails when a
    # class has a single member, so fall back to the full data then.
    can_hold_out = (
        test_size > 0
        and len(X) >= 10
        and min(n_concept, n_negative) >= 2
    )
    if can_hold_out:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=self.random_seed, stratify=y
        )
    else:
        X_train, y_train = X, y
        X_test, y_test = X, y

    # Train classifier; anything other than "sgd" uses logistic regression.
    if self.cav_classifier == "sgd":
        classifier = SGDClassifier(
            loss='hinge',
            max_iter=1000,
            random_state=self.random_seed,
            n_jobs=-1
        )
    else:  # logistic
        classifier = LogisticRegression(
            max_iter=1000,
            random_state=self.random_seed,
            n_jobs=-1,
            solver='lbfgs'
        )

    classifier.fit(X_train, y_train)

    # accuracy_score returns numpy.float64; convert to Python float
    accuracy = float(accuracy_score(y_test, classifier.predict(X_test)))

    # Extract CAV (normal vector to separating hyperplane).
    # For linear classifiers, this is the coefficient vector.
    cav_vector = classifier.coef_.flatten()

    return cav_vector, classifier, accuracy
310
+
311
def learn_concept(
    self,
    concept_name: str,
    concept_examples: np.ndarray,
    negative_examples: Optional[np.ndarray] = None,
    test_size: float = 0.2,
    min_accuracy: float = 0.6
) -> "ConceptActivationVector":
    """
    Learn a Concept Activation Vector from examples.

    The CAV is the direction in activation space that separates
    concept examples from negative (non-concept) examples.

    Args:
        concept_name: Human-readable name for the concept.
        concept_examples: Examples that exhibit the concept.
            Shape: (n_concept, ...) matching model input.
        negative_examples: Examples that don't exhibit the concept.
            If None, Gaussian noise with the same shape as
            `concept_examples` is used (not recommended); the noise is
            drawn from a generator seeded with `self.random_seed`, so
            it is identical on every call.
            Shape: (n_negative, ...) matching model input.
        test_size: Fraction of data to hold out for accuracy estimation.
        min_accuracy: Minimum accuracy for CAV to be considered valid.
            Low accuracy suggests concept isn't linearly separable.

    Returns:
        The learned ConceptActivationVector (also stored in
        `self.concepts` under `concept_name`).

    Raises:
        ValueError: If CAV accuracy is below min_accuracy threshold.
    """
    concept_examples = np.array(concept_examples)

    if negative_examples is None:
        # Generate random negative examples (not recommended)
        import warnings
        warnings.warn(
            "No negative examples provided. Using random noise. "
            "For meaningful CAVs, provide actual non-concept examples.",
            UserWarning,
            stacklevel=2
        )
        # Use a private, seeded generator: drawing from the global RNG
        # made the noise depend on whatever had seeded it before (and
        # unseeded on a first call), defeating `random_seed`.
        noise_rng = np.random.RandomState(self.random_seed)
        negative_examples = noise_rng.randn(
            *concept_examples.shape
        ).astype(np.float32)
    else:
        negative_examples = np.array(negative_examples)

    # Extract activations
    concept_acts = self._get_activations(concept_examples)
    negative_acts = self._get_activations(negative_examples)

    # Train CAV (accuracy is already a Python float from _train_cav)
    cav_vector, classifier, accuracy = self._train_cav(
        concept_acts, negative_acts, test_size
    )

    if accuracy < min_accuracy:
        raise ValueError(
            f"CAV accuracy ({accuracy:.3f}) is below threshold ({min_accuracy}). "
            f"The concept '{concept_name}' may not be linearly separable in "
            f"layer '{self.layer_name}'. Consider using a different layer "
            f"or providing more/better examples."
        )

    cav = ConceptActivationVector(
        concept_name=concept_name,
        layer_name=self.layer_name,
        vector=cav_vector,
        classifier=classifier,
        accuracy=accuracy,
        metadata={
            "n_concept_examples": int(len(concept_examples)),
            "n_negative_examples": int(len(negative_examples)),
            "test_size": float(test_size)
        }
    )

    # Store the CAV (replaces any previous CAV with the same name)
    self.concepts[concept_name] = cav

    return cav
391
+
392
def learn_random_concepts(
    self,
    negative_examples: np.ndarray,
    n_random: int = 10,
    concept_name_prefix: str = "_random"
) -> List["ConceptActivationVector"]:
    """
    Learn random CAVs for statistical significance testing.

    Random CAVs are trained by splitting random examples into
    two arbitrary halves. They serve as a baseline to test
    whether a concept's TCAV score is significantly different
    from random.

    The i-th split is a permutation drawn from a generator seeded with
    `self.random_seed + i`. A CAV whose training fails is skipped with
    a RuntimeWarning, so fewer than `n_random` CAVs may be returned.

    Args:
        negative_examples: Pool of examples to sample from.
        n_random: Number of random CAVs to train.
        concept_name_prefix: Prefix for random concept names; also the
            key under which the CAVs are stored in `self.random_concepts`.

    Returns:
        List of random CAVs (possibly shorter than `n_random`).
    """
    import warnings

    negative_examples = np.array(negative_examples)

    # Activations are computed once and re-split for every random CAV.
    all_acts = self._get_activations(negative_examples)
    n_samples = len(all_acts)
    split_point = n_samples // 2

    random_cavs = []
    for i in range(n_random):
        seed = self.random_seed + i
        cav_name = f"{concept_name_prefix}_{i}"

        # A local generator yields the same permutation as seeding the
        # global RNG with `seed`, without depending on global state.
        indices = np.random.RandomState(seed).permutation(n_samples)
        group1_acts = all_acts[indices[:split_point]]
        group2_acts = all_acts[indices[split_point:]]

        # Train CAV on the arbitrary split
        try:
            cav_vector, classifier, accuracy = self._train_cav(
                group1_acts, group2_acts, test_size=0.0
            )
            cav = ConceptActivationVector(
                concept_name=cav_name,
                layer_name=self.layer_name,
                vector=cav_vector,
                classifier=classifier,
                accuracy=accuracy,
                metadata={"random_seed": int(seed)}
            )
        except Exception as exc:
            # Best effort: keep going, but surface the failure instead of
            # silently returning fewer CAVs than requested.
            warnings.warn(
                f"Skipping random CAV '{cav_name}': {exc}",
                RuntimeWarning,
                stacklevel=2
            )
            continue

        random_cavs.append(cav)

    # Store random CAVs (replaces any earlier list under this prefix)
    self.random_concepts[concept_name_prefix] = random_cavs

    return random_cavs
455
+
456
def compute_directional_derivative(
    self,
    inputs: np.ndarray,
    cav: "ConceptActivationVector",
    target_class: int
) -> np.ndarray:
    """
    Project the layer gradients onto the CAV direction.

    For each input, the result measures how the model's output for the
    target class changes when the layer activations move along the CAV:

        S_C,k(x) = ∇h_l,k(x) · v_C

    where h_l,k is the model's logit for class k at layer l,
    and v_C is the CAV direction.

    Args:
        inputs: Input data (n_samples, ...)
        cav: The Concept Activation Vector
        target_class: Target class index

    Returns:
        Directional derivatives as a numpy array, one value per input.
        Convert individual values to float if native types are needed.
    """
    layer_grads = self._get_gradients_wrt_layer(inputs, target_class)

    # One dot product per sample: gradient row · CAV direction.
    return np.dot(layer_grads, cav.vector)
491
+
492
def compute_tcav_score(
    self,
    test_inputs: np.ndarray,
    target_class: int,
    concept_name: str,
    return_derivatives: bool = False
) -> Union[float, Tuple[float, np.ndarray]]:
    """
    Compute the TCAV score of a learned concept for a target class.

    The score is the share of test inputs whose directional derivative
    along the concept's CAV is strictly positive:

        TCAV_C,k = |{x : S_C,k(x) > 0}| / |X|

    A score > 0.5 indicates the concept positively influences
    the prediction, while < 0.5 indicates negative influence.

    Args:
        test_inputs: Test examples to compute TCAV score over.
        target_class: Target class index.
        concept_name: Name of the concept (must be learned first).
        return_derivatives: If True, also return the directional derivatives.

    Returns:
        TCAV score as Python float in [0, 1].
        If return_derivatives=True, returns (score, derivatives) where
        score is Python float and derivatives is numpy array.

    Raises:
        ValueError: If `concept_name` has not been learned.
    """
    try:
        cav = self.concepts[concept_name]
    except KeyError:
        raise ValueError(
            f"Concept '{concept_name}' not found. "
            f"Available concepts: {list(self.concepts.keys())}. "
            f"Use learn_concept() first."
        ) from None

    derivatives = self.compute_directional_derivative(
        np.array(test_inputs), cav, target_class
    )

    # Fraction of strictly positive derivatives, as a native float
    # (np.mean yields numpy.float64).
    positive_mask = derivatives > 0
    tcav_score = float(np.mean(positive_mask))

    return (tcav_score, derivatives) if return_derivatives else tcav_score
544
+
545
def statistical_significance_test(
    self,
    test_inputs: np.ndarray,
    target_class: int,
    concept_name: str,
    n_random: int = 10,
    negative_examples: Optional[np.ndarray] = None,
    alpha: float = 0.05
) -> Dict[str, Any]:
    """
    Test statistical significance of TCAV score against random concepts.

    Performs a two-sided one-sample t-test of the random concepts' TCAV
    scores against the concept's TCAV score. Random CAVs are cached in
    `self.random_concepts` under a key built from the concept name and
    target class, and are (re)trained when fewer than `n_random` are
    cached.

    Args:
        test_inputs: Test examples to compute TCAV scores over.
        target_class: Target class index.
        concept_name: Name of the concept to test.
        n_random: Number of random concepts for comparison.
        negative_examples: Examples for training random concepts.
            If None, uses test_inputs (not ideal).
        alpha: Significance level for the test.

    Returns:
        Dictionary containing (all values are Python native types):
            - tcav_score: The concept's TCAV score (float)
            - random_scores: List of random TCAV scores (list of float)
            - random_mean: Mean of random scores (float; NaN if none)
            - random_std: Population std of random scores (float; NaN if none)
            - t_statistic: t-statistic from t-test (float; NaN when fewer
              than two random scores are available)
            - p_value: p-value from t-test (float; NaN in the same case)
            - significant: Whether the result is significant at level alpha (bool)
            - effect_size: (tcav_score - random_mean) / random_std, using
              the population std; +/-inf when the std is zero and the
              means differ, NaN when there are no random scores (float)
            - alpha: The significance level used (float)
    """
    test_inputs = np.array(test_inputs)

    # Concept TCAV score (already a Python float)
    concept_score = self.compute_tcav_score(
        test_inputs, target_class, concept_name
    )

    if negative_examples is None:
        negative_examples = test_inputs

    random_prefix = f"_random_{concept_name}_{target_class}"

    # Train random CAVs unless enough are already cached
    if random_prefix not in self.random_concepts or \
            len(self.random_concepts[random_prefix]) < n_random:
        self.learn_random_concepts(
            negative_examples,
            n_random=n_random,
            concept_name_prefix=random_prefix
        )

    random_cavs = self.random_concepts[random_prefix][:n_random]

    # TCAV score of every random CAV, as native floats
    random_scores = [
        float(np.mean(
            self.compute_directional_derivative(
                test_inputs, random_cav, target_class
            ) > 0
        ))
        for random_cav in random_cavs
    ]
    random_scores_array = np.array(random_scores)
    n_scores = random_scores_array.size
    nan = float("nan")

    # A t-test needs at least two samples; with fewer the statistic is
    # undefined, so report NaN directly instead of calling scipy.
    if n_scores >= 2:
        t_stat_np, p_value_np = stats.ttest_1samp(
            random_scores_array, concept_score
        )
        t_stat = float(t_stat_np)
        p_value = float(p_value_np)
    else:
        t_stat = nan
        p_value = nan

    if n_scores == 0:
        # No baseline at all: every summary statistic is undefined.
        # (Comparing against a NaN mean used to yield effect_size=-inf.)
        random_mean = nan
        random_std = nan
        effect_size = nan
    else:
        random_std = float(random_scores_array.std())
        random_mean = float(random_scores_array.mean())
        if random_std > 0:
            effect_size = float((concept_score - random_mean) / random_std)
        elif concept_score != random_mean:
            # Zero spread but different means: unbounded effect
            effect_size = float('inf') if concept_score > random_mean else float('-inf')
        else:
            effect_size = 0.0

    # Python bool, not numpy.bool_; a NaN p-value is never significant
    significant = bool(p_value < alpha)

    return {
        "tcav_score": concept_score,
        "random_scores": random_scores,
        "random_mean": random_mean,
        "random_std": random_std,
        "t_statistic": t_stat,
        "p_value": p_value,
        "significant": significant,
        "effect_size": effect_size,
        "alpha": float(alpha)
    }
652
+
653
def explain(
    self,
    test_inputs: np.ndarray,
    target_class: Optional[int] = None,
    concept_names: Optional[List[str]] = None,
    run_significance_test: bool = False,
    negative_examples: Optional[np.ndarray] = None,
    n_random: int = 10
) -> Explanation:
    """
    Generate TCAV explanation for test inputs.

    Computes TCAV scores for all (or specified) concepts
    and optionally runs statistical significance tests.

    Args:
        test_inputs: Input examples to explain.
        target_class: Target class to explain. If None, the most common
            argmax of `self.model.predict(test_inputs)` is used when the
            predictions have more than one dimension, otherwise class 0.
        concept_names: List of concepts to include. If None,
            uses all learned concepts. Names that have not been
            learned are skipped.
        run_significance_test: Whether to run statistical tests.
        negative_examples: Examples for random CAVs (for significance
            test). Defaults to `test_inputs` when None.
        n_random: Number of random concepts for significance test.

    Returns:
        Explanation object with TCAV scores for each analyzed concept.
        Its "concepts_analyzed" entry lists only the concepts that were
        actually scored.

    Raises:
        ValueError: If no concept has been learned yet.
    """
    test_inputs = np.array(test_inputs)

    if not self.concepts:
        raise ValueError(
            "No concepts learned. Use learn_concept() first."
        )

    # Determine target class
    if target_class is None:
        predictions = self.model.predict(test_inputs)
        if predictions.ndim > 1:
            # Most frequent predicted class across the batch
            target_class = int(np.argmax(np.bincount(
                np.argmax(predictions, axis=1)
            )))
        else:
            target_class = 0

    # Determine concepts to analyze
    if concept_names is None:
        concept_names = list(self.concepts.keys())

    tcav_scores = {}
    significance_results = {}

    for concept_name in concept_names:
        if concept_name not in self.concepts:
            continue

        score, derivatives = self.compute_tcav_score(
            test_inputs, target_class, concept_name,
            return_derivatives=True
        )

        tcav_scores[concept_name] = {
            "score": score,
            "cav_accuracy": self.concepts[concept_name].accuracy,
            "positive_count": int(np.sum(derivatives > 0)),
            "total_count": int(len(derivatives))
        }

        # Optionally run significance test
        if run_significance_test:
            neg_examples = negative_examples if negative_examples is not None else test_inputs
            significance_results[concept_name] = self.statistical_significance_test(
                test_inputs, target_class, concept_name,
                n_random=n_random,
                negative_examples=neg_examples
            )

    # Determine class name. Fall back to a generic label when the index
    # is outside `class_names`, instead of raising IndexError or letting
    # a negative index pick an unrelated name.
    if self.class_names and 0 <= target_class < len(self.class_names):
        label_name = self.class_names[target_class]
    else:
        label_name = f"class_{target_class}"

    explanation_data = {
        "tcav_scores": tcav_scores,
        "target_class": int(target_class),
        "n_test_inputs": int(len(test_inputs)),
        "layer_name": self.layer_name,
        # Only concepts that were scored; unknown names were skipped above
        "concepts_analyzed": list(tcav_scores)
    }

    if run_significance_test:
        explanation_data["significance_tests"] = significance_results

    return Explanation(
        explainer_name="TCAV",
        target_class=label_name,
        explanation_data=explanation_data
    )
756
+
757
def explain_batch(
    self,
    X: np.ndarray,
    target_class: Optional[int] = None,
    **kwargs
) -> List["Explanation"]:
    """
    Explain a whole batch with a single TCAV explanation.

    TCAV scores are defined over a set of inputs, so the batch is
    analyzed together rather than input by input. The single resulting
    explanation is wrapped in a list.

    Args:
        X: Batch of inputs.
        target_class: Target class to explain.
        **kwargs: Additional arguments passed to explain().

    Returns:
        List containing a single Explanation for the batch.
    """
    batch_explanation = self.explain(X, target_class=target_class, **kwargs)
    return [batch_explanation]
779
+
780
def get_most_influential_concepts(
    self,
    test_inputs: np.ndarray,
    target_class: int,
    top_k: int = 5
) -> List[Tuple[str, float]]:
    """
    Rank the learned concepts by TCAV score for the target class.

    Every learned concept is scored on `test_inputs`; a higher score
    means a more positive influence on the target class. Concepts with
    equal scores keep the order in which they were learned.

    Args:
        test_inputs: Test examples.
        target_class: Target class index.
        top_k: Number of top concepts to return.

    Returns:
        List of (concept_name, tcav_score) tuples, sorted by score descending.
        All scores are Python floats.
    """
    ranked = sorted(
        (
            (name, self.compute_tcav_score(test_inputs, target_class, name))
            for name in self.concepts
        ),
        key=lambda pair: pair[1],
        reverse=True,
    )
    return ranked[:top_k]
814
+
815
def compare_concepts(
    self,
    test_inputs: np.ndarray,
    target_classes: List[int],
    concept_names: Optional[List[str]] = None
) -> Dict[str, Dict[int, float]]:
    """
    Build a concept-by-class table of TCAV scores.

    Useful for seeing which concepts matter for which classes.

    Args:
        test_inputs: Test examples.
        target_classes: List of class indices to compare.
        concept_names: Concepts to analyze (default: all).

    Returns:
        Dictionary mapping concept names to {class_idx: tcav_score}.
        All scores are Python floats.
    """
    selected = list(self.concepts.keys()) if concept_names is None else concept_names

    return {
        name: {
            class_idx: self.compute_tcav_score(test_inputs, class_idx, name)
            for class_idx in target_classes
        }
        for name in selected
    }
851
+
852
def list_concepts(self) -> List[str]:
    """Return the names of all learned concepts as a new list."""
    return [name for name in self.concepts]
855
+
856
def get_concept(self, concept_name: str) -> "ConceptActivationVector":
    """Return the learned CAV stored under `concept_name`.

    Raises:
        ValueError: If no concept with that name has been learned.
    """
    if concept_name in self.concepts:
        return self.concepts[concept_name]
    raise ValueError(f"Concept '{concept_name}' not found.")
861
+
862
def remove_concept(self, concept_name: str) -> None:
    """Forget a learned concept; unknown names are ignored."""
    self.concepts.pop(concept_name, None)