nous 0.0.1-py3-none-any.whl → 0.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of nous might be problematic.

nous/explain/loo.py ADDED
@@ -0,0 +1,130 @@
+ from __future__ import annotations
+ import numpy as np
+ import pandas as pd
+ import torch
+ from typing import Optional, Sequence
+ from ..model import NousNet
+ from .facts_desc import render_fact_descriptions
+ from .aggregator import format_agg_mixture
+
+ def rule_impact_df(
+     model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None, vs_class: Optional[int] = None,
+     loo_mode: str = 'replacement',  # 'replacement' | 'frozen' | 'active_only'
+     top_m_rules: Optional[int] = None,
+     use_pruning: bool = False, pruning_threshold: float = 0.0, use_pre_norm: bool = False,
+     base_state=None, device=None
+ ) -> pd.DataFrame:
+     """
+     Honest leave-one-out at the rule level with recomputed gating (before top-k),
+     supporting multiple modes ('replacement', 'frozen', 'active_only').
+     """
+     device = device or next(model.parameters()).device
+     task = model.config['task_type']
+
+     if base_state is None:
+         base_probas, base_logits, base_internals = model.forward_explain(
+             x_sample, apply_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
+         )
+     else:
+         base_probas, base_logits, base_internals = base_state
+
+     if task == "classification":
+         base_logits = np.array(base_logits)
+         pred_idx = int(np.argmax(base_probas))
+         if vs_class is None:
+             runner_up_idx = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
+         else:
+             runner_up_idx = int(vs_class)
+         base_margin = float(base_logits[pred_idx] - base_logits[runner_up_idx])
+     else:
+         base_pred = float(base_logits[0])
+
+     restrict_masks = None
+     if loo_mode in ('frozen', 'active_only'):
+         # Freeze the current active set by extracting gate masks and reusing them as restrict masks.
+         restrict_masks = []
+         block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
+         for key in block_keys:
+             gm = base_internals[key]['gate_mask']
+             if isinstance(gm, torch.Tensor):
+                 gm = gm.squeeze(0)
+             restrict_masks.append((gm > 0).float())
+
+     fact_desc = render_fact_descriptions(model, feature_names)
+
+     rows = []
+     block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
+     for b_idx, key in enumerate(block_keys):
+         details = base_internals[key]
+         metric_tensor = details['pre_norm_sum'] if (use_pre_norm and 'pre_norm_sum' in details) else details['gated_activations']
+         ga_np = metric_tensor.squeeze(0).abs().cpu().numpy() if isinstance(metric_tensor, torch.Tensor) else np.abs(metric_tensor)
+         active_rules = np.where(ga_np > 1e-12)[0]
+         if len(active_rules) == 0:
+             continue
+
+         if top_m_rules is not None and len(active_rules) > top_m_rules:
+             order = np.argsort(-ga_np[active_rules])
+             active_rules = active_rules[order[:top_m_rules]]
+
+         agg_w = details['aggregator_weights']
+         if isinstance(agg_w, torch.Tensor):
+             agg_w = agg_w.cpu()
+
+         facts_used = details.get('facts_used', None)
+         if isinstance(facts_used, torch.Tensor):
+             facts_used = facts_used.cpu().numpy()
+
+         for r in active_rules:
+             drop_spec = (b_idx, int(r))
+             restr = None if loo_mode == 'replacement' else restrict_masks
+
+             drop_probas, drop_logits, _ = model.forward_explain(
+                 x_sample, drop_rule_spec=drop_spec, restrict_masks=restr,
+                 apply_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
+             )
+
+             if agg_w is not None:
+                 aggs = format_agg_mixture(agg_w[r])
+             else:
+                 aggs = "AND (fixed)"
+
+             facts_str = "—"
+             if facts_used is not None and facts_used.shape[0] > r:
+                 used = facts_used[r]
+                 if np.ndim(used) == 0:
+                     used = [int(used)]
+                 facts_str = ", ".join([f"F{int(fid)+1}" for fid in used])
+                 facts_str += " → " + " | ".join([f"[F{int(fid)+1}] {fact_desc[int(fid)]}" for fid in used])
+
+             if task == "classification":
+                 drop_logits = np.array(drop_logits)
+                 dlogit = float(base_logits[pred_idx] - drop_logits[pred_idx])
+                 drop_margin = float(drop_logits[pred_idx] - drop_logits[runner_up_idx])
+                 dmargin = float(base_margin - drop_margin)
+                 critical = (np.argmax(drop_probas) != pred_idx)
+                 supports = dmargin > 0
+                 rows.append({
+                     "block": b_idx + 1, "rule": int(r + 1),
+                     "aggregators": aggs, "facts": facts_str,
+                     "Δlogit(pred)": dlogit,
+                     f"Δmargin(vs {class_names[runner_up_idx] if class_names else runner_up_idx})": dmargin,
+                     "supports_pred": supports, "critical_flip": critical
+                 })
+             else:
+                 drop_pred = float(drop_logits[0])
+                 d_pred = base_pred - drop_pred
+                 rows.append({
+                     "block": b_idx + 1, "rule": int(r + 1),
+                     "aggregators": aggs, "facts": facts_str,
+                     "Δprediction": d_pred
+                 })
+
+     df = pd.DataFrame(rows)
+     if not df.empty:
+         if task == "classification":
+             key = [c for c in df.columns if c.startswith("Δmargin(")][0]
+             df['abs_impact'] = df[key].abs()
+         else:
+             df['abs_impact'] = df['Δprediction'].abs()
+         df = df.sort_values(by='abs_impact', ascending=False).drop(columns=['abs_impact']).reset_index(drop=True)
+     return df
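
A minimal usage sketch for `rule_impact_df`, assuming a trained `NousNet` classifier `model`, a validation array `X_val`, and `feature_names`/`class_names` lists (hypothetical names, not part of this diff):

    # Rank rules by leave-one-out impact for one sample.
    imp = rule_impact_df(
        model, X_val[0], feature_names, class_names=class_names,
        loo_mode='frozen',   # freeze the active set so only the dropped rule changes
        top_m_rules=20,      # cap the number of LOO forward passes per block
    )
    print(imp.head())        # strongest rules first (sorted by |Δmargin|)
    # 'critical_flip' marks rules whose removal alone changes the predicted class.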
nous/explain/mse.py ADDED
@@ -0,0 +1,93 @@
+ from __future__ import annotations
+ import numpy as np
+ import torch
+ from typing import Optional, Sequence, Dict, Any, List
+ from ..model import NousNet
+ from .loo import rule_impact_df
+
+ def minimal_sufficient_explanation(
+     model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+     margin_tolerance: float = 0.0, prob_tolerance: float = 0.0, pred_tolerance_reg: float = 0.02,
+     freeze_non_active: bool = True, device=None
+ ) -> Dict[str, Any]:
+     """
+     Greedy backward elimination among the originally active rules.
+     Preserves the prediction (classification: class unchanged and margin within
+     tolerance; regression: prediction within an absolute tolerance).
+     """
+     device = device or next(model.parameters()).device
+     task = model.config['task_type']
+
+     base_probas, base_logits, base_internals = model.forward_explain(x_sample, device=device)
+     if task == "classification":
+         base_logits = np.array(base_logits)  # ensure ndarray (for .size / argsort), as in loo.py
+         pred_idx = int(np.argmax(base_probas))
+         runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
+         base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
+         base_conf = float(base_probas[pred_idx])
+     else:
+         base_pred = float(base_logits[0])
+
+     # Extract the active set (frozen)
+     active_masks: List[torch.Tensor] = []
+     block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
+     for key in block_keys:
+         gm = base_internals[key]['gate_mask']
+         if isinstance(gm, torch.Tensor):
+             gm = gm.squeeze(0)
+         active_masks.append((gm > 0).float())
+     current_masks = [m.clone() for m in active_masks]
+
+     # Rank removal candidates by absolute impact, ascending
+     imp = rule_impact_df(
+         model, x_sample, feature_names, class_names=class_names,
+         loo_mode='frozen', device=device
+     )
+     if imp.empty:
+         return {
+             "kept_masks": current_masks, "kept": [],
+             "removed": [], "pred_preserved": True,
+             "size": sum(int(m.sum().item()) for m in current_masks)
+         }
+
+     if task == "classification":
+         margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]
+         imp = imp.assign(order=imp[margin_col].abs()).sort_values("order", ascending=True).drop(columns=["order"])
+     else:
+         imp = imp.assign(order=imp["Δprediction"].abs()).sort_values("order", ascending=True).drop(columns=["order"])
+
+     removed = []
+     for _, row in imp.iterrows():
+         b = int(row["block"]) - 1
+         r = int(row["rule"]) - 1
+         if current_masks[b][r] == 0:
+             continue
+
+         proposal = [m.clone() for m in current_masks]
+         proposal[b][r] = 0.0
+         probas2, logits2, _ = model.forward_explain(x_sample, restrict_masks=proposal, device=device)
+
+         if task == "classification":
+             new_pred = int(np.argmax(probas2))
+             if new_pred != pred_idx:
+                 continue
+             new_margin = float(logits2[pred_idx] - logits2[runner_up])
+             new_conf = float(probas2[pred_idx])
+             if new_margin < base_margin - margin_tolerance:
+                 continue
+             if prob_tolerance > 0.0 and (new_conf < base_conf - prob_tolerance):
+                 continue
+         else:
+             new_pred = float(logits2[0])
+             if abs(new_pred - base_pred) > pred_tolerance_reg:
+                 continue
+
+         current_masks = proposal
+         removed.append((b + 1, r + 1))
+
+     kept = [(i + 1, int(idx.item()) + 1) for i, m in enumerate(current_masks) for idx in torch.where(m > 0)[0]]
+     return {
+         "kept_masks": current_masks,
+         "kept": kept,
+         "removed": removed,
+         "pred_preserved": True,
+         "size": sum(int(m.sum().item()) for m in current_masks)
+     }
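
A sketch of how `minimal_sufficient_explanation` might be called on the same hypothetical `model` and sample; the tolerance value is illustrative:

    mse = minimal_sufficient_explanation(
        model, X_val[0], feature_names, class_names=class_names,
        margin_tolerance=0.1,   # allow the winning margin to shrink by up to 0.1
    )
    print(mse["size"], "rules kept;", len(mse["removed"]), "removed")
    print(mse["kept"])          # (block, rule) pairs, 1-indexed as in rule_impact_df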
@@ -0,0 +1,117 @@
+ from __future__ import annotations
+ import numpy as np
+ from typing import Optional
+ from sklearn.metrics import mean_absolute_error
+ from ..model import NousNet
+
+ def select_pruning_threshold_global(
+     model: NousNet, X_val, target_fidelity: float = 0.99, task_type: Optional[str] = None,
+     metric_reg: str = "mae",  # reserved; only MAE is currently implemented
+     tol_reg: float = 0.05, pointwise_tol: Optional[float] = None,
+     max_samples: int = 500, device=None
+ ) -> float:
+     """
+     Grid search over the post-gating activation threshold, keeping prediction
+     fidelity (classification) or MAE within tolerance (regression).
+     """
+     device = device or next(model.parameters()).device
+     task = task_type or model.config['task_type']
+
+     n = min(len(X_val), max_samples)
+     Xv = X_val[:n]
+
+     base_preds = []
+     acts = []
+     for i in range(n):
+         _, logit_b, internals = model.forward_explain(Xv[i], device=device)
+         base_preds.append(int(np.argmax(logit_b)) if task == "classification" else float(logit_b[0]))
+         for key in [k for k in internals.keys() if k.startswith("block_")]:
+             ga = internals[key]['gated_activations'].abs().cpu().numpy().ravel()
+             acts.extend(list(ga))
+     acts = np.array(acts)
+     if acts.size == 0:
+         return 0.0
+
+     qs = np.linspace(0.5, 0.999, 25)
+     candidates = np.unique(np.quantile(acts, qs))
+
+     best_t = 0.0
+     for t in candidates:
+         if task == "classification":
+             agree = 0
+             for i in range(n):
+                 _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(t), device=device)
+                 agree += int(int(np.argmax(logit_p)) == base_preds[i])
+             fidelity = agree / n
+             if fidelity >= target_fidelity:
+                 best_t = float(t)
+         else:
+             preds_p, preds_b = [], []
+             max_abs = 0.0
+             for i in range(n):
+                 _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(t), device=device)
+                 _, logit_b, _ = model.forward_explain(Xv[i], device=device)
+                 pv, bv = float(logit_p[0]), float(logit_b[0])
+                 preds_p.append(pv); preds_b.append(bv)
+                 max_abs = max(max_abs, abs(pv - bv))
+             mae_p = mean_absolute_error(preds_b, preds_p)
+             ok_mae = (mae_p <= tol_reg)
+             ok_point = True if pointwise_tol is None else (max_abs <= pointwise_tol)
+             if ok_mae and ok_point:
+                 best_t = float(t)
+     return best_t
+
+
+ def select_pruning_threshold_global_bs(
+     model: NousNet, X_val, target_fidelity: float = 0.99, task_type: Optional[str] = None,
+     metric_reg: str = "mae",  # reserved; only MAE is currently implemented
+     tol_reg: float = 0.05, pointwise_tol: Optional[float] = None,
+     max_samples: int = 500, device=None
+ ) -> float:
+     """
+     Binary-search selection of the post-gating activation threshold under the
+     same fidelity/MAE constraints.
+     """
+     device = device or next(model.parameters()).device
+     task = task_type or model.config['task_type']
+     n = min(len(X_val), max_samples)
+     Xv = X_val[:n]
+
+     base_preds = []
+     acts = []
+     for i in range(n):
+         _, logit_b, internals = model.forward_explain(Xv[i], device=device)
+         base_preds.append(int(np.argmax(logit_b)) if task == "classification" else float(logit_b[0]))
+         for key in [k for k in internals.keys() if k.startswith("block_")]:
+             ga = internals[key]['gated_activations'].abs().cpu().numpy().ravel()
+             acts.extend(list(ga))
+     acts = np.array(acts)
+     if acts.size == 0:
+         return 0.0
+
+     lo, hi = 0.0, float(np.quantile(acts, 0.999))
+     best_t = 0.0
+     for _ in range(14):
+         mid = (lo + hi) / 2.0
+         if task == "classification":
+             agree = 0
+             for i in range(n):
+                 _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(mid), device=device)
+                 agree += int(int(np.argmax(logit_p)) == base_preds[i])
+             fidelity = agree / n
+             if fidelity >= target_fidelity:
+                 best_t = mid; lo = mid
+             else:
+                 hi = mid
+         else:
+             preds_p, preds_b, max_abs = [], [], 0.0
+             for i in range(n):
+                 _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(mid), device=device)
+                 _, logit_b, _ = model.forward_explain(Xv[i], device=device)
+                 pv, bv = float(logit_p[0]), float(logit_b[0])
+                 preds_p.append(pv); preds_b.append(bv)
+                 max_abs = max(max_abs, abs(pv - bv))
+             mae_p = mean_absolute_error(preds_b, preds_p)
+             ok_mae = (mae_p <= tol_reg)
+             ok_point = True if pointwise_tol is None else (max_abs <= pointwise_tol)
+             if ok_mae and ok_point:
+                 best_t = mid; lo = mid
+             else:
+                 hi = mid
+     return float(best_t)
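
A sketch of threshold selection on held-out data (same hypothetical `model`/`X_val` as above). Note that the binary-search variant implicitly assumes fidelity degrades monotonically as the threshold grows; the grid variant makes no such assumption but always evaluates all 25 candidates:

    t = select_pruning_threshold_global_bs(model, X_val, target_fidelity=0.99)
    probas, logits, internals = model.forward_explain(
        X_val[0], apply_pruning=True, pruning_threshold=t
    )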
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+ import numpy as np
+ import torch
+ from typing import Optional, Sequence, Dict, Any
+ from ..model import NousNet
+ from .loo import rule_impact_df
+
+ def explanation_stability(
+     model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+     k_top: int = 5, sigma: float = 0.05, trials: int = 20, loo_mode: str = 'frozen',
+     use_pruning: bool = False, pruning_threshold: float = 0.0, device=None
+ ) -> Dict[str, Any]:
+     """
+     Stability of the top-k rule explanation under small input perturbations.
+     Returns the mean/std Jaccard overlap with the base top-k selection.
+     """
+     device = device or next(model.parameters()).device
+     base_imp = rule_impact_df(
+         model, x_sample, feature_names, class_names=class_names,
+         loo_mode=loo_mode, use_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
+     )
+     base_set = set((int(r.block), int(r.rule)) for _, r in base_imp.head(k_top).iterrows()) if not base_imp.empty else set()
+
+     overlaps = []
+     for _ in range(trials):
+         x = torch.tensor(x_sample, dtype=torch.float32, device=device)
+         noise = torch.randn_like(x) * sigma
+         x_noisy = (x + noise).cpu().numpy()
+
+         imp = rule_impact_df(
+             model, x_noisy, feature_names, class_names=class_names,
+             loo_mode=loo_mode, use_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
+         )
+         cur_set = set((int(r.block), int(r.rule)) for _, r in imp.head(k_top).iterrows()) if not imp.empty else set()
+         inter = len(base_set & cur_set)
+         union = len(base_set | cur_set) if (base_set | cur_set) else 1
+         overlaps.append(inter / union)
+     return {
+         "mean_jaccard": float(np.mean(overlaps)),
+         "std_jaccard": float(np.std(overlaps)),
+         "base_top_rules": list(base_set)
+     }
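
A sketch of a stability check (assumes standardized inputs, so `sigma=0.05` is a small perturbation; names as in the earlier sketches):

    st = explanation_stability(model, X_val[0], feature_names,
                               class_names=class_names, k_top=5,
                               sigma=0.05, trials=20)
    print(f"Jaccard: {st['mean_jaccard']:.2f} ± {st['std_jaccard']:.2f}")

A mean Jaccard near 1.0 means the same top-5 rules are selected under noise; values near 0 indicate an unstable explanation.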
nous/explain/traces.py ADDED
@@ -0,0 +1,285 @@
+ from __future__ import annotations
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from typing import Optional, Sequence, List, Dict, Any
+ from ..model import NousNet
+ from ..prototypes import ScaledPrototypeLayer
+ from ..explain.aggregator import AGG_NAMES
+ from ..explain.facts_desc import render_fact_descriptions
+
+ @torch.no_grad()
+ def get_last_block_static_metadata(model: NousNet, top_k_facts_override: Optional[int] = None):
+     """
+     Returns:
+       - agg_weights: torch.Tensor [R, A] or None
+       - facts_used: np.ndarray [R, K] (indices of inputs to the last block) or None
+     """
+     blk = model.blocks[-1]
+     agg = None
+     facts_used = None
+
+     if hasattr(blk, "aggregator_logits"):
+         agg = F.softmax(blk.aggregator_logits, dim=1).detach().cpu()
+
+     # Which fact indices each rule uses depends on the block type.
+     from ..rules.blocks import SimpleNousBlock
+     from ..rules.softmax import SoftmaxRuleLayer
+     from ..rules.sparse import SparseRuleLayer
+
+     if isinstance(blk, SimpleNousBlock):
+         facts_used = blk.rule.idx.detach().cpu().numpy()  # [R, 2]
+     elif isinstance(blk, SoftmaxRuleLayer):
+         k = top_k_facts_override or getattr(blk, "top_k_facts", 2)
+         fl = F.softmax(blk.fact_logits, dim=1)
+         _, topk = torch.topk(fl, k=min(k, blk.input_dim), dim=1)
+         facts_used = topk.detach().cpu().numpy()  # [R, k]
+     elif isinstance(blk, SparseRuleLayer):
+         k = top_k_facts_override or getattr(blk, "top_k_facts", 2)
+         prob = blk.hard_concrete.get_proba().detach()
+         _, topk = torch.topk(prob, k=min(k, blk.input_dim), dim=1)
+         facts_used = topk.cpu().numpy()
+     return agg, facts_used
+
+ @torch.no_grad()
+ def _block_fact_mapping(block, top_k_per_rule: int = 2):
+     """
+     Returns np.ndarray [R, k] — input indices of the given block for each rule.
+     """
+     from ..rules.blocks import SimpleNousBlock
+     from ..rules.softmax import SoftmaxRuleLayer
+     from ..rules.sparse import SparseRuleLayer
+
+     if isinstance(block, SimpleNousBlock):
+         return block.rule.idx.detach().cpu().numpy()
+     elif isinstance(block, SoftmaxRuleLayer):
+         fl = F.softmax(block.fact_logits, dim=1)
+         _, topk = torch.topk(fl, k=min(top_k_per_rule, block.input_dim), dim=1)
+         return topk.detach().cpu().numpy()
+     elif isinstance(block, SparseRuleLayer):
+         prob = block.hard_concrete.get_proba().detach()
+         _, topk = torch.topk(prob, k=min(top_k_per_rule, block.input_dim), dim=1)
+         return topk.cpu().numpy()
+     else:
+         return None
+
+ @torch.no_grad()
+ def trace_rule_to_base_facts(model: NousNet, rule_idx_last: int, top_k_per_step: int = 2) -> List[int]:
+     """
+     Trace a rule from the last block down to the base β-facts.
+     """
+     selected = {int(rule_idx_last)}
+     for b in reversed(range(len(model.blocks))):
+         blk = model.blocks[b]
+         mapping = _block_fact_mapping(blk, top_k_per_rule=top_k_per_step)
+         if mapping is None:
+             return []
+         prev_sel = set()
+         for r in selected:
+             if r < 0 or r >= mapping.shape[0]:
+                 continue
+             prev_sel.update([int(i) for i in mapping[r].tolist()])
+         selected = prev_sel
+     return sorted(selected)
+
+ @torch.no_grad()
+ def prototype_top_rules(model: NousNet, proto_id: int, top_k_rules: int = 10):
+     """
+     Returns a list of (rule_idx, weight) from the prototype vector, sorted by |weight|.
+     """
+     assert isinstance(model.head, ScaledPrototypeLayer), "Prototype head is not enabled."
+     P = model.head.prototypes.detach().cpu().numpy()
+     v = P[proto_id]
+     idx = np.argsort(-np.abs(v))[:top_k_rules]
+     return [(int(i), float(v[i])) for i in idx]
+
+ @torch.no_grad()
+ def prototype_contribution_df(model: NousNet, x_sample, class_names: Optional[Sequence[str]] = None, top_k: int = 5, device=None):
+     """
+     For a single example — the contribution of each prototype to the predicted class.
+     """
+     import pandas as pd
+
+     assert isinstance(model.head, ScaledPrototypeLayer), "Prototype head is not enabled."
+     device = device or next(model.parameters()).device
+
+     h = model.encode(np.array([x_sample]), device=device)
+     d, act = model.head.compute_dist_act(h.to(device))
+     act = act.squeeze(0).cpu().numpy()
+     d = d.squeeze(0).cpu().numpy()
+
+     W = model.head.prototype_class.detach().cpu().numpy()  # [M, C]
+     probs, logits, _ = model.forward_explain(x_sample, device=device)
+     c_hat = int(np.argmax(probs))
+     w_c = W[:, c_hat]
+     contrib = act * w_c
+
+     Wsm = F.softmax(model.head.prototype_class, dim=1).detach().cpu().numpy()
+     ent = (-Wsm * np.clip(np.log(Wsm + 1e-12), -50, 50)).sum(axis=1)
+     primary = np.argmax(Wsm, axis=1)
+
+     order = np.argsort(-np.abs(contrib))[:top_k]
+     rows = []
+     for j in order:
+         rows.append({
+             "proto": int(j),
+             "distance": float(d[j]),
+             "activation": float(act[j]),
+             "w_c": float(w_c[j]),
+             "contribution": float(contrib[j]),
+             "primary_class": (class_names[primary[j]] if class_names is not None else int(primary[j])),
+             "entropy": float(ent[j])
+         })
+     return pd.DataFrame(rows)
+
+ @torch.no_grad()
+ def prototype_report_global(
+     model: NousNet, X, y: Optional[Sequence[int]] = None, class_names: Optional[Sequence[str]] = None,
+     top_k_rules: int = 8, top_k_facts_per_rule: int = 2, trace_to_base: bool = True,
+     k_neighbors: int = 10, chunk_size: int = 2048, device=None
+ ):
+     """
+     Global prototype report:
+       - primary_class probabilities and entropy,
+       - mean activation and top-1 frequency,
+       - top rules and (optionally) tracing down to base β-facts,
+       - nearest training examples and their labels (if y is provided).
+     """
+     import pandas as pd
+     assert isinstance(model.head, ScaledPrototypeLayer), "Prototype head is not enabled."
+     device = device or next(model.parameters()).device
+     model.eval()
+
+     H = model.encode(X, device=device)
+     Hn = F.normalize(H, p=2, dim=1)
+     P = model.head.get_params()
+     Pn = P["prototypes_norm"]
+     Wsm = P["class_probs"].cpu().numpy()
+     Wraw = P["class_weights"].cpu().numpy()
+
+     M = Pn.shape[0]
+     N = Hn.shape[0]
+     top1_count = np.zeros(M, dtype=int)
+     mean_act = np.zeros(M, dtype=float)
+     neigh_idx: list[list[int]] = [[] for _ in range(M)]
+
+     for i in range(0, N, chunk_size):
+         h_chunk = Hn[i:i+chunk_size].to(device)
+         d = torch.cdist(h_chunk, Pn.to(device))
+         act = torch.exp(-F.softplus(model.head.temperature) * d).cpu().numpy()
+         top_idx = np.argmax(act, axis=1)
+         for t in top_idx:
+             top1_count[int(t)] += 1
+         mean_act += act.sum(axis=0)
+         k = min(k_neighbors, act.shape[0])
+         for m in range(M):
+             order = np.argsort(d[:, m].cpu().numpy())[:k]
+             neigh_idx[m].extend((i + order).tolist())
+
+     mean_act /= max(1, N)
+
+     agg_last, facts_last = get_last_block_static_metadata(model, top_k_facts_override=top_k_facts_per_rule)
+     fd = render_fact_descriptions(model, model.config['feature_names'])
+
+     def agg_row_to_str(aw_row):
+         if aw_row is None:
+             return "AND (fixed)"
+         A = int(aw_row.shape[0]) if hasattr(aw_row, "shape") else len(AGG_NAMES)
+         names = AGG_NAMES[:A]
+         parts = []
+         for i, a in enumerate(aw_row.tolist()):
+             if i < len(names) and a > 1e-6:
+                 parts.append(f"{float(a):.2f} {names[i]}")
+         return " + ".join(parts) if parts else "∅"
+
+     rule_meta = []
+     for m in range(M):
+         top_rules = prototype_top_rules(model, m, top_k_rules=top_k_rules)
+         rules_desc = []
+         for (r, w) in top_rules:
+             r = int(r)
+             agg_str = "AND (fixed)"
+             if agg_last is not None and 0 <= r < agg_last.shape[0]:
+                 aw = agg_last[r]
+                 agg_str = agg_row_to_str(aw)
+             facts_str = ""
+             if trace_to_base:
+                 base_facts = trace_rule_to_base_facts(model, r, top_k_per_step=top_k_facts_per_rule)
+                 if base_facts:
+                     facts_str = " | ".join([f"[F{int(fid)+1}] {fd[int(fid)]}" for fid in base_facts[:top_k_facts_per_rule]])
+             elif facts_last is not None and r < facts_last.shape[0]:
+                 prev_units = facts_last[r]
+                 facts_str = ", ".join([f"Unit{int(u)}" for u in prev_units])
+             rules_desc.append(dict(rule=int(r), weight=float(w), aggregators=agg_str, facts=facts_str))
+         rule_meta.append(rules_desc)
+
+     rows = []
+     for m in range(M):
+         primary = int(np.argmax(Wsm[m]))
+         rows.append({
+             "proto": m,
+             "primary_class": (class_names[primary] if class_names is not None else primary),
+             "class_probs": Wsm[m].tolist(),
+             "class_weights": Wraw[m].tolist(),
+             "class_entropy": float((-Wsm[m] * np.log(Wsm[m] + 1e-12)).sum()),
+             "mean_activation": float(mean_act[m]),
+             "top1_count": int(top1_count[m]),
+             "neighbors": neigh_idx[m][:k_neighbors],
+             "neighbor_labels": [int(y[idx]) for idx in neigh_idx[m][:k_neighbors]] if y is not None else None,
+             "top_rules_meta": rule_meta[m]
+         })
+     return pd.DataFrame(rows)
+
+ @torch.no_grad()
+ def describe_prototype(model: NousNet, proto_id: int, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+                        top_k_rules: int = 8, top_k_facts_per_rule: int = 2, trace_to_base: bool = True) -> str:
+     """
+     Human-readable description of a single prototype:
+       - class affinity (weights),
+       - top rules and aggregator mixtures,
+       - base β-facts (if traced).
+     """
+     assert isinstance(model.head, ScaledPrototypeLayer), "Prototype head is not enabled."
+
+     P = model.head.get_params()
+     Wsm = P["class_probs"].cpu().numpy()
+     m = int(proto_id)
+     primary = int(np.argmax(Wsm[m]))
+     primary_name = class_names[primary] if class_names is not None else f"Class {primary}"
+
+     lines = []
+     lines.append(f"Prototype #{m}")
+     lines.append(f"  Primary class: {primary_name} | probs={np.round(Wsm[m], 3).tolist()} | entropy={float((-Wsm[m]*np.log(Wsm[m]+1e-12)).sum()):.3f}")
+     lines.append("  Top rules:")
+     agg_last, _ = get_last_block_static_metadata(model, top_k_facts_override=top_k_facts_per_rule)
+     fd = render_fact_descriptions(model, feature_names)
+
+     def agg_row_to_str(aw_row):
+         if aw_row is None:
+             return "AND (fixed)"
+         A = int(aw_row.shape[0]) if hasattr(aw_row, "shape") else len(AGG_NAMES)
+         names = AGG_NAMES[:A]
+         parts = []
+         for i, a in enumerate(aw_row.tolist()):
+             if i < len(names) and a > 1e-6:
+                 parts.append(f"{float(a):.2f} {names[i]}")
+         return " + ".join(parts) if parts else "∅"
+
+     for (r, w) in prototype_top_rules(model, m, top_k_rules=top_k_rules):
+         r = int(r)
+         agg_str = "AND (fixed)"
+         if agg_last is not None and 0 <= r < agg_last.shape[0]:
+             aw = agg_last[r]
+             agg_str = agg_row_to_str(aw)
+
+         facts_str = ""
+         if trace_to_base:
+             base_facts = trace_rule_to_base_facts(model, r, top_k_per_step=top_k_facts_per_rule)
+             if base_facts:
+                 facts_str = " | ".join([f"[F{int(fid)+1}] {fd[int(fid)]}" for fid in base_facts[:top_k_facts_per_rule]])
+
+         lines.append(f"  • R{r+1}: weight={w:+.3f} | {agg_str}")
+         if facts_str:
+             lines.append(f"      {facts_str}")
+     return "\n".join(lines)
nous/explain/utils.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+ import torch
+
+ def extract_active_masks(internals) -> list[torch.Tensor]:
+     """
+     Extract per-block active gate masks (0/1) from the internals produced by forward_explain.
+     """
+     masks = []
+     keys = sorted([k for k in internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
+     for key in keys:
+         gate_mask = internals[key]['gate_mask']
+         if isinstance(gate_mask, torch.Tensor):
+             gate_mask = gate_mask.squeeze(0)
+         masks.append((gate_mask > 0).float())
+     return masks
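
A sketch of the intended round trip, based on my reading of the `restrict_masks` parameter of `forward_explain` as used in loo.py (not verified here): masks extracted from one pass and fed back in should reproduce the original prediction.

    _, _, internals = model.forward_explain(X_val[0])
    masks = extract_active_masks(internals)
    probas, logits, _ = model.forward_explain(X_val[0], restrict_masks=masks)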