nous-0.1.0-py3-none-any.whl → nous-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nous has been flagged as possibly problematic.
- nous/__init__.py +96 -19
- nous/data/__init__.py +4 -0
- nous/data/california.py +32 -0
- nous/data/wine.py +29 -0
- nous/explain/__init__.py +26 -0
- nous/explain/aggregator.py +34 -0
- nous/explain/cf.py +137 -0
- nous/explain/facts_desc.py +23 -0
- nous/explain/fidelity.py +56 -0
- nous/explain/generate.py +86 -0
- nous/explain/global_book.py +52 -0
- nous/explain/loo.py +130 -0
- nous/explain/mse.py +93 -0
- nous/explain/pruning.py +117 -0
- nous/explain/stability.py +42 -0
- nous/explain/traces.py +285 -0
- nous/explain/utils.py +15 -0
- nous/export/__init__.py +13 -0
- nous/export/numpy_infer.py +412 -0
- nous/facts.py +112 -0
- nous/model.py +226 -0
- nous/prototypes.py +43 -0
- nous/rules/__init__.py +11 -0
- nous/rules/blocks.py +63 -0
- nous/rules/fixed.py +26 -0
- nous/rules/softmax.py +93 -0
- nous/rules/sparse.py +142 -0
- nous/training/__init__.py +5 -0
- nous/training/evaluation.py +57 -0
- nous/training/schedulers.py +34 -0
- nous/training/train.py +177 -0
- nous/types.py +4 -0
- nous/utils/__init__.py +3 -0
- nous/utils/metrics.py +2 -0
- nous/utils/seed.py +13 -0
- nous/version.py +1 -0
- nous-0.2.0.dist-info/METADATA +150 -0
- nous-0.2.0.dist-info/RECORD +41 -0
- nous/causal.py +0 -63
- nous/interpret.py +0 -111
- nous/layers.py +0 -117
- nous/models.py +0 -65
- nous-0.1.0.dist-info/METADATA +0 -138
- nous-0.1.0.dist-info/RECORD +0 -10
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/explain/loo.py
ADDED
@@ -0,0 +1,130 @@
from __future__ import annotations
import numpy as np
import pandas as pd
import torch
from typing import Optional, Sequence
from ..model import NousNet
from .facts_desc import render_fact_descriptions
from .aggregator import format_agg_mixture

def rule_impact_df(
    model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None, vs_class: Optional[int] = None,
    loo_mode: str = 'replacement',  # 'replacement' | 'frozen' | 'active_only'
    top_m_rules: Optional[int] = None,
    use_pruning: bool = False, pruning_threshold: float = 0.0, use_pre_norm: bool = False,
    base_state=None, device=None
) -> pd.DataFrame:
    """
    Honest leave-one-out at the rule level with recomputed gating (before top-k),
    supporting multiple modes (replacement, frozen, active_only).
    """
    device = device or next(model.parameters()).device
    task = model.config['task_type']

    if base_state is None:
        base_probas, base_logits, base_internals = model.forward_explain(
            x_sample, apply_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
        )
    else:
        base_probas, base_logits, base_internals = base_state

    if task == "classification":
        base_logits = np.array(base_logits)
        pred_idx = int(np.argmax(base_probas))
        if vs_class is None:
            runner_up_idx = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
        else:
            runner_up_idx = int(vs_class)
        base_margin = float(base_logits[pred_idx] - base_logits[runner_up_idx])
    else:
        base_pred = float(base_logits[0])

    restrict_masks = None
    if loo_mode in ('frozen', 'active_only'):
        # Freeze current active set by extracting gate masks and reusing as restrict masks.
        restrict_masks = []
        block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
        for key in block_keys:
            gm = base_internals[key]['gate_mask']
            if isinstance(gm, torch.Tensor):
                gm = gm.squeeze(0)
            restrict_masks.append((gm > 0).float())

    fact_desc = render_fact_descriptions(model, feature_names)

    rows = []
    block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
    for b_idx, key in enumerate(block_keys):
        details = base_internals[key]
        metric_tensor = details['pre_norm_sum'] if (use_pre_norm and 'pre_norm_sum' in details) else details['gated_activations']
        ga_np = metric_tensor.squeeze(0).abs().cpu().numpy() if isinstance(metric_tensor, torch.Tensor) else np.abs(metric_tensor)
        active_rules = np.where(ga_np > 1e-12)[0]
        if len(active_rules) == 0:
            continue

        if top_m_rules is not None and len(active_rules) > top_m_rules:
            order = np.argsort(-ga_np[active_rules])
            active_rules = active_rules[order[:top_m_rules]]

        agg_w = details['aggregator_weights']
        if isinstance(agg_w, torch.Tensor):
            agg_w = agg_w.cpu()

        facts_used = details.get('facts_used', None)
        if isinstance(facts_used, torch.Tensor):
            facts_used = facts_used.cpu().numpy()

        for r in active_rules:
            drop_spec = (b_idx, int(r))
            restr = None if loo_mode == 'replacement' else restrict_masks

            drop_probas, drop_logits, _ = model.forward_explain(
                x_sample, drop_rule_spec=drop_spec, restrict_masks=restr,
                apply_pruning=use_pruning, pruning_threshold=pruning_threshold, device=device
            )

            if agg_w is not None:
                aggs = format_agg_mixture(agg_w[r])
            else:
                aggs = "AND (fixed)"

            facts_str = "—"
            if facts_used is not None and facts_used.shape[0] > r:
                used = facts_used[r]
                if np.ndim(used) == 0:
                    used = [int(used)]
                facts_str = ", ".join([f"F{int(fid)+1}" for fid in used])
                facts_str += " → " + " | ".join([f"[F{int(fid)+1}] {fact_desc[int(fid)]}" for fid in used])

            if task == "classification":
                drop_logits = np.array(drop_logits)
                dlogit = float(base_logits[pred_idx] - drop_logits[pred_idx])
                drop_margin = float(drop_logits[pred_idx] - drop_logits[runner_up_idx])
                dmargin = float(base_margin - drop_margin)
                critical = (np.argmax(drop_probas) != pred_idx)
                supports = dmargin > 0
                rows.append({
                    "block": b_idx+1, "rule": int(r+1),
                    "aggregators": aggs, "facts": facts_str,
                    "Δlogit(pred)": dlogit,
                    f"Δmargin(vs {class_names[runner_up_idx] if class_names else runner_up_idx})": dmargin,
                    "supports_pred": supports, "critical_flip": critical
                })
            else:
                drop_pred = float(drop_logits[0])
                d_pred = base_pred - drop_pred
                rows.append({
                    "block": b_idx+1, "rule": int(r+1),
                    "aggregators": aggs, "facts": facts_str,
                    "Δprediction": d_pred
                })

    df = pd.DataFrame(rows)
    if not df.empty:
        if task == "classification":
            key = [c for c in df.columns if c.startswith("Δmargin(")][0]
            df['abs_impact'] = df[key].abs()
        else:
            df['abs_impact'] = df['Δprediction'].abs()
        df = df.sort_values(by='abs_impact', ascending=False).drop(columns=['abs_impact']).reset_index(drop=True)
    return df
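A minimal usage sketch for the new `rule_impact_df` API (not part of the diff; `model`, `X_test`, and `feature_names` are stand-ins for objects produced by your own training pipeline):

```python
# Hypothetical usage: rank each active rule by its leave-one-out impact.
from nous.explain.loo import rule_impact_df

imp = rule_impact_df(
    model, X_test[0], feature_names,
    class_names=["class_0", "class_1"],
    loo_mode="frozen",   # re-score each rule with the remaining active set frozen
    top_m_rules=10,      # only probe the 10 most strongly activated rules
)
print(imp.head())        # rows sorted by |Δmargin| for classification tasks
```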
nous/explain/mse.py
ADDED
@@ -0,0 +1,93 @@
from __future__ import annotations
import numpy as np
import torch
from typing import Optional, Sequence, Dict, Any, List
from ..model import NousNet
from .loo import rule_impact_df

def minimal_sufficient_explanation(
    model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
    margin_tolerance: float = 0.0, prob_tolerance: float = 0.0, pred_tolerance_reg: float = 0.02,
    freeze_non_active: bool = True, device=None
) -> Dict[str, Any]:
    """
    Greedy backward elimination among originally active rules.
    Preserves the prediction (class unchanged with margins within tolerance; regression within abs tol).
    """
    device = device or next(model.parameters()).device
    task = model.config['task_type']

    base_probas, base_logits, base_internals = model.forward_explain(x_sample, device=device)
    if task == "classification":
        pred_idx = int(np.argmax(base_probas))
        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
        base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
        base_conf = float(base_probas[pred_idx])
    else:
        base_pred = float(base_logits[0])

    # Extract active set (frozen)
    active_masks: List[torch.Tensor] = []
    block_keys = sorted([k for k in base_internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
    for key in block_keys:
        gm = base_internals[key]['gate_mask']
        if isinstance(gm, torch.Tensor):
            gm = gm.squeeze(0)
        active_masks.append((gm > 0).float())
    current_masks = [m.clone() for m in active_masks]

    # Rank removal candidates by (absolute) impact ascending
    imp = rule_impact_df(
        model, x_sample, feature_names, class_names=class_names,
        loo_mode='frozen', device=device
    )
    if imp.empty:
        return {
            "kept_masks": current_masks, "kept": [],
            "removed": [], "pred_preserved": True,
            "size": sum(int(m.sum().item()) for m in current_masks)
        }

    if task == "classification":
        margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]
        imp = imp.assign(order=imp[margin_col].abs()).sort_values("order", ascending=True).drop(columns=["order"])
    else:
        imp = imp.assign(order=imp["Δprediction"].abs()).sort_values("order", ascending=True).drop(columns=["order"])

    removed = []
    for _, row in imp.iterrows():
        b = int(row["block"]) - 1
        r = int(row["rule"]) - 1
        if current_masks[b][r] == 0:
            continue

        proposal = [m.clone() for m in current_masks]
        proposal[b][r] = 0.0
        probas2, logits2, _ = model.forward_explain(x_sample, restrict_masks=proposal, device=device)

        if task == "classification":
            new_pred = int(np.argmax(probas2))
            if new_pred != pred_idx:
                continue
            new_margin = float(logits2[pred_idx] - logits2[runner_up])
            new_conf = float(probas2[pred_idx])
            if new_margin < base_margin - margin_tolerance:
                continue
            if prob_tolerance > 0.0 and (new_conf < base_conf - prob_tolerance):
                continue
        else:
            new_pred = float(logits2[0])
            if abs(new_pred - base_pred) > pred_tolerance_reg:
                continue

        current_masks = proposal
        removed.append((b+1, r+1))

    kept = [(i+1, int(idx.item())+1) for i, m in enumerate(current_masks) for idx in torch.where(m > 0)[0]]
    return {
        "kept_masks": current_masks,
        "kept": kept,
        "removed": removed,
        "pred_preserved": True,
        "size": sum(int(m.sum().item()) for m in current_masks)
    }
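A sketch of how the greedy elimination might be called (not part of the diff; `model`, `X_test`, and `feature_names` are placeholders from your own pipeline):

```python
# Hypothetical call: shrink the active rule set while keeping the prediction.
from nous.explain.mse import minimal_sufficient_explanation

mse = minimal_sufficient_explanation(
    model, X_test[0], feature_names,
    margin_tolerance=0.1,   # allow the winning margin to shrink by up to 0.1
)
print(mse["kept"])          # (block, rule) pairs that survived elimination
print(mse["removed"], mse["size"])
```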
nous/explain/pruning.py
ADDED
@@ -0,0 +1,117 @@
from __future__ import annotations
import numpy as np
from typing import Optional
from sklearn.metrics import mean_absolute_error
from ..model import NousNet

def select_pruning_threshold_global(
    model: NousNet, X_val, target_fidelity: float = 0.99, task_type: Optional[str] = None,
    metric_reg: str = "mae", tol_reg: float = 0.05, pointwise_tol: Optional[float] = None,
    max_samples: int = 500, device=None
) -> float:
    """
    Grid selection of activation threshold (post-gating) keeping fidelity (classification) or MAE (regression).
    """
    device = device or next(model.parameters()).device
    task = task_type or model.config['task_type']

    n = min(len(X_val), max_samples)
    Xv = X_val[:n]

    base_preds = []
    acts = []
    for i in range(n):
        _, logit_b, internals = model.forward_explain(Xv[i], device=device)
        base_preds.append(int(np.argmax(logit_b)) if task == "classification" else float(logit_b[0]))
        for key in [k for k in internals.keys() if k.startswith("block_")]:
            ga = internals[key]['gated_activations'].abs().cpu().numpy().ravel()
            acts.extend(list(ga))
    acts = np.array(acts)
    if acts.size == 0:
        return 0.0

    qs = np.linspace(0.5, 0.999, 25)
    candidates = np.unique(np.quantile(acts, qs))

    best_t = 0.0
    for t in candidates:
        if task == "classification":
            agree = 0
            for i in range(n):
                _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(t), device=device)
                agree += int(int(np.argmax(logit_p)) == base_preds[i])
            fidelity = agree / n
            if fidelity >= target_fidelity:
                best_t = float(t)
        else:
            preds_p, preds_b = [], []
            max_abs = 0.0
            for i in range(n):
                _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(t), device=device)
                _, logit_b, _ = model.forward_explain(Xv[i], device=device)
                pv, bv = float(logit_p[0]), float(logit_b[0])
                preds_p.append(pv); preds_b.append(bv)
                max_abs = max(max_abs, abs(pv - bv))
            mae_p = mean_absolute_error(preds_b, preds_p)
            ok_mae = (mae_p <= tol_reg)
            ok_point = True if pointwise_tol is None else (max_abs <= pointwise_tol)
            if ok_mae and ok_point:
                best_t = float(t)
    return best_t


def select_pruning_threshold_global_bs(
    model: NousNet, X_val, target_fidelity: float = 0.99, task_type: Optional[str] = None,
    metric_reg: str = "mae", tol_reg: float = 0.05, pointwise_tol: Optional[float] = None,
    max_samples: int = 500, device=None
) -> float:
    """
    Binary search selection of activation threshold (post-gating) with fidelity/MAE constraints.
    """
    device = device or next(model.parameters()).device
    task = task_type or model.config['task_type']
    n = min(len(X_val), max_samples)
    Xv = X_val[:n]

    base_preds = []
    acts = []
    for i in range(n):
        _, logit_b, internals = model.forward_explain(Xv[i], device=device)
        base_preds.append(int(np.argmax(logit_b)) if task == "classification" else float(logit_b[0]))
        for key in [k for k in internals.keys() if k.startswith("block_")]:
            ga = internals[key]['gated_activations'].abs().cpu().numpy().ravel()
            acts.extend(list(ga))
    acts = np.array(acts)
    if acts.size == 0:
        return 0.0

    lo, hi = 0.0, float(np.quantile(acts, 0.999))
    best_t = 0.0
    for _ in range(14):
        mid = (lo + hi) / 2.0
        if task == "classification":
            agree = 0
            for i in range(n):
                _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(mid), device=device)
                agree += int(int(np.argmax(logit_p)) == base_preds[i])
            fidelity = agree / n
            if fidelity >= target_fidelity:
                best_t = mid; lo = mid
            else:
                hi = mid
        else:
            preds_p, preds_b, max_abs = [], [], 0.0
            for i in range(n):
                _, logit_p, _ = model.forward_explain(Xv[i], apply_pruning=True, pruning_threshold=float(mid), device=device)
                _, logit_b, _ = model.forward_explain(Xv[i], device=device)
                pv, bv = float(logit_p[0]), float(logit_b[0])
                preds_p.append(pv); preds_b.append(bv)
                max_abs = max(max_abs, abs(pv - bv))
            mae_p = mean_absolute_error(preds_b, preds_p)
            ok_mae = (mae_p <= tol_reg)
            ok_point = True if pointwise_tol is None else (max_abs <= pointwise_tol)
            if ok_mae and ok_point:
                best_t = mid; lo = mid
            else:
                hi = mid
    return float(best_t)
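A sketch of the intended calibration flow, assuming a held-out `X_val` from your own pipeline (illustrative, not part of the diff):

```python
# Hypothetical usage: pick a global pruning threshold on validation data,
# then reuse it in subsequent forward_explain calls.
from nous.explain.pruning import select_pruning_threshold_global_bs

t = select_pruning_threshold_global_bs(model, X_val, target_fidelity=0.99)
probas, logits, internals = model.forward_explain(
    X_val[0], apply_pruning=True, pruning_threshold=t
)
```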
nous/explain/stability.py
ADDED
@@ -0,0 +1,42 @@
from __future__ import annotations
import numpy as np
import torch
from typing import Optional, Sequence, Dict, Any
from ..model import NousNet
from .loo import rule_impact_df

def explanation_stability(
    model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
    k_top: int = 5, sigma: float = 0.05, trials: int = 20, loo_mode: str = 'frozen',
    use_pruning: bool = False, pruning_threshold: float = 0.0, device=None
) -> Dict[str, Any]:
    """
    Stability of top-k rule explanations under small input perturbations.
    Returns mean/std Jaccard overlap vs base top-k selection.
    """
    device = device or next(model.parameters()).device
    base_imp = rule_impact_df(
        model, x_sample, feature_names, class_names=class_names,
        loo_mode=loo_mode, use_pruning=use_pruning, pruning_threshold=pruning_threshold
    )
    base_set = set((int(r.block), int(r.rule)) for _, r in base_imp.head(k_top).iterrows()) if not base_imp.empty else set()

    overlaps = []
    for _ in range(trials):
        x = torch.tensor(x_sample, dtype=torch.float32, device=device)
        noise = torch.randn_like(x) * sigma
        x_noisy = (x + noise).cpu().numpy()

        imp = rule_impact_df(
            model, x_noisy, feature_names, class_names=class_names,
            loo_mode=loo_mode, use_pruning=use_pruning, pruning_threshold=pruning_threshold
        )
        cur_set = set((int(r.block), int(r.rule)) for _, r in imp.head(k_top).iterrows()) if not imp.empty else set()
        inter = len(base_set & cur_set)
        union = len(base_set | cur_set) if (base_set | cur_set) else 1
        overlaps.append(inter / union)
    return {
        "mean_jaccard": float(np.mean(overlaps)),
        "std_jaccard": float(np.std(overlaps)),
        "base_top_rules": list(base_set)
    }
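A minimal sketch of the stability check (illustrative; `model`, `X_test`, and `feature_names` are placeholders, and `sigma` assumes standardized inputs):

```python
# Hypothetical usage: Jaccard overlap of top-5 rules under Gaussian noise.
from nous.explain.stability import explanation_stability

rep = explanation_stability(
    model, X_test[0], feature_names,
    k_top=5, sigma=0.05, trials=20,
)
print(f"Jaccard: {rep['mean_jaccard']:.2f} ± {rep['std_jaccard']:.2f}")
print(rep["base_top_rules"])   # the unperturbed (block, rule) top-k set
```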
nous/explain/traces.py
ADDED
@@ -0,0 +1,285 @@
from __future__ import annotations
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional, Sequence, List, Dict, Any
from ..model import NousNet
from ..prototypes import ScaledPrototypeLayer
from ..explain.aggregator import AGG_NAMES
from ..explain.facts_desc import render_fact_descriptions

@torch.no_grad()
def get_last_block_static_metadata(model: NousNet, top_k_facts_override: Optional[int] = None):
    """
    Returns:
      - agg_weights: torch.Tensor [R, A] or None
      - facts_used: np.ndarray [R, K] (indices of inputs to the last block) or None
    """
    blk = model.blocks[-1]
    agg = None
    facts_used = None

    if hasattr(blk, "aggregator_logits"):
        agg = F.softmax(blk.aggregator_logits, dim=1).detach().cpu()

    # Fact indices used per rule depend on block type
    from ..rules.blocks import SimpleNousBlock
    from ..rules.softmax import SoftmaxRuleLayer
    from ..rules.sparse import SparseRuleLayer

    if isinstance(blk, SimpleNousBlock):
        facts_used = blk.rule.idx.detach().cpu().numpy()  # [R, 2]
    elif isinstance(blk, SoftmaxRuleLayer):
        k = top_k_facts_override or getattr(blk, "top_k_facts", 2)
        fl = F.softmax(blk.fact_logits, dim=1)
        _, topk = torch.topk(fl, k=min(k, blk.input_dim), dim=1)
        facts_used = topk.detach().cpu().numpy()  # [R, k]
    elif isinstance(blk, SparseRuleLayer):
        k = top_k_facts_override or getattr(blk, "top_k_facts", 2)
        prob = blk.hard_concrete.get_proba().detach()
        _, topk = torch.topk(prob, k=min(k, blk.input_dim), dim=1)
        facts_used = topk.cpu().numpy()
    return agg, facts_used

@torch.no_grad()
def _block_fact_mapping(block, top_k_per_rule: int = 2):
    """
    Returns np.ndarray [R, k] — input indices of the given block for each rule.
    """
    from ..rules.blocks import SimpleNousBlock
    from ..rules.softmax import SoftmaxRuleLayer
    from ..rules.sparse import SparseRuleLayer

    if isinstance(block, SimpleNousBlock):
        return block.rule.idx.detach().cpu().numpy()
    elif isinstance(block, SoftmaxRuleLayer):
        fl = F.softmax(block.fact_logits, dim=1)
        _, topk = torch.topk(fl, k=min(top_k_per_rule, block.input_dim), dim=1)
        return topk.detach().cpu().numpy()
    elif isinstance(block, SparseRuleLayer):
        prob = block.hard_concrete.get_proba().detach()
        _, topk = torch.topk(prob, k=min(top_k_per_rule, block.input_dim), dim=1)
        return topk.cpu().numpy()
    else:
        return None

@torch.no_grad()
def trace_rule_to_base_facts(model: NousNet, rule_idx_last: int, top_k_per_step: int = 2) -> List[int]:
    """
    Trace a rule from the last block down to base β-facts.
    """
    selected = {int(rule_idx_last)}
    for b in reversed(range(len(model.blocks))):
        blk = model.blocks[b]
        mapping = _block_fact_mapping(blk, top_k_per_rule=top_k_per_step)
        if mapping is None:
            return []
        prev_sel = set()
        for r in selected:
            if r < 0 or r >= mapping.shape[0]:
                continue
            prev_sel.update([int(i) for i in mapping[r].tolist()])
        selected = prev_sel
    return sorted(selected)

@torch.no_grad()
def prototype_top_rules(model: NousNet, proto_id: int, top_k_rules: int = 10):
    """
    Returns list of (rule_idx, weight) from the prototype vector sorted by |weight|.
    """
    assert isinstance(model.head, ScaledPrototypeLayer), "Prototypes head is not enabled."
    P = model.head.prototypes.detach().cpu().numpy()
    v = P[proto_id]
    idx = np.argsort(-np.abs(v))[:top_k_rules]
    return [(int(i), float(v[i])) for i in idx]

@torch.no_grad()
def prototype_contribution_df(model: NousNet, x_sample, class_names: Optional[Sequence[str]] = None, top_k: int = 5, device=None):
    """
    For a single example — contribution of prototypes to the predicted class.
    """
    import pandas as pd

    assert isinstance(model.head, ScaledPrototypeLayer), "Prototypes head is not enabled."
    device = device or next(model.parameters()).device

    h = model.encode(np.array([x_sample]), device=device)
    d, act = model.head.compute_dist_act(h.to(device))
    act = act.squeeze(0).cpu().numpy()
    d = d.squeeze(0).cpu().numpy()

    W = model.head.prototype_class.detach().cpu().numpy()  # [M, C]
    probs, logits, _ = model.forward_explain(x_sample, device=device)
    c_hat = int(np.argmax(probs))
    w_c = W[:, c_hat]
    contrib = act * w_c

    Wsm = F.softmax(model.head.prototype_class, dim=1).detach().cpu().numpy()
    ent = (-Wsm * np.clip(np.log(Wsm + 1e-12), -50, 50)).sum(axis=1)
    primary = np.argmax(Wsm, axis=1)

    order = np.argsort(-np.abs(contrib))[:top_k]
    rows = []
    for j in order:
        rows.append({
            "proto": int(j),
            "distance": float(d[j]),
            "activation": float(act[j]),
            "w_c": float(w_c[j]),
            "contribution": float(contrib[j]),
            "primary_class": (class_names[primary[j]] if class_names is not None else int(primary[j])),
            "entropy": float(ent[j])
        })
    return pd.DataFrame(rows)

@torch.no_grad()
def prototype_report_global(
    model: NousNet, X, y: Optional[Sequence[int]] = None, class_names: Optional[Sequence[str]] = None,
    top_k_rules: int = 8, top_k_facts_per_rule: int = 2, trace_to_base: bool = True,
    k_neighbors: int = 10, chunk_size: int = 2048, device=None
):
    """
    Global prototype report:
      - primary_class probabilities and entropy,
      - average activation, top1 frequency,
      - top rules and (optionally) tracing to base β-facts,
      - nearest training examples and labels (if y provided).
    """
    import pandas as pd
    assert isinstance(model.head, ScaledPrototypeLayer), "Prototypes head is not enabled."
    device = device or next(model.parameters()).device
    model.eval()

    H = model.encode(X, device=device)
    Hn = F.normalize(H, p=2, dim=1)
    P = model.head.get_params()
    Pn = P["prototypes_norm"]
    Wsm = P["class_probs"].cpu().numpy()
    Wraw = P["class_weights"].cpu().numpy()

    M = Pn.shape[0]
    N = Hn.shape[0]
    top1_count = np.zeros(M, dtype=int)
    mean_act = np.zeros(M, dtype=float)
    neigh_idx: list[list[int]] = [[] for _ in range(M)]

    for i in range(0, N, chunk_size):
        h_chunk = Hn[i:i+chunk_size].to(device)
        d = torch.cdist(h_chunk, Pn.to(device))
        act = torch.exp(-F.softplus(model.head.temperature) * d).cpu().numpy()
        top_idx = np.argmax(act, axis=1)
        for t in top_idx:
            top1_count[int(t)] += 1
        mean_act += act.sum(axis=0)
        k = min(k_neighbors, act.shape[0])
        for m in range(M):
            order = np.argsort(d[:, m].cpu().numpy())[:k]
            neigh_idx[m].extend((i + order).tolist())

    mean_act /= max(1, N)

    agg_last, facts_last = get_last_block_static_metadata(model, top_k_facts_override=top_k_facts_per_rule)
    fd = render_fact_descriptions(model, model.config['feature_names'])

    def agg_row_to_str(aw_row):
        if aw_row is None:
            return "AND (fixed)"
        A = int(aw_row.shape[0]) if hasattr(aw_row, "shape") else len(AGG_NAMES)
        names = AGG_NAMES[:A]
        parts = []
        for i, a in enumerate(aw_row.tolist()):
            if i < len(names) and a > 1e-6:
                parts.append(f"{float(a):.2f} {names[i]}")
        return " + ".join(parts) if parts else "∅"

    rule_meta = []
    for m in range(M):
        top_rules = prototype_top_rules(model, m, top_k_rules=top_k_rules)
        rules_desc = []
        for (r, w) in top_rules:
            r = int(r)
            agg_str = "AND (fixed)"
            if agg_last is not None and 0 <= r < agg_last.shape[0]:
                aw = agg_last[r]
                agg_str = agg_row_to_str(aw)
            facts_str = ""
            if trace_to_base:
                base_facts = trace_rule_to_base_facts(model, r, top_k_per_step=top_k_facts_per_rule)
                if base_facts:
                    facts_str = " | ".join([f"[F{int(fid)+1}] {fd[int(fid)]}" for fid in base_facts[:top_k_facts_per_rule]])
            elif facts_last is not None and r < facts_last.shape[0]:
                prev_units = facts_last[r]
                facts_str = ", ".join([f"Unit{int(u)}" for u in prev_units])
            rules_desc.append(dict(rule=int(r), weight=float(w), aggregators=agg_str, facts=facts_str))
        rule_meta.append(rules_desc)

    rows = []
    for m in range(M):
        primary = int(np.argmax(Wsm[m]))
        rows.append({
            "proto": m,
            "primary_class": (class_names[primary] if class_names is not None else primary),
            "class_probs": Wsm[m].tolist(),
            "class_weights": Wraw[m].tolist(),
            "class_entropy": float((-Wsm[m] * np.log(Wsm[m] + 1e-12)).sum()),
            "mean_activation": float(mean_act[m]),
            "top1_count": int(top1_count[m]),
            "neighbors": neigh_idx[m][:k_neighbors],
            "neighbor_labels": [int(y[idx]) for idx in neigh_idx[m][:k_neighbors]] if y is not None else None,
            "top_rules_meta": rule_meta[m]
        })
    return pd.DataFrame(rows)

@torch.no_grad()
def describe_prototype(model: NousNet, proto_id: int, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
                       top_k_rules: int = 8, top_k_facts_per_rule: int = 2, trace_to_base: bool = True) -> str:
    """
    Human-readable description for a single prototype:
      - class affinity (weights),
      - top rules and aggregator mixtures,
      - base β-facts (if traced).
    """
    assert isinstance(model.head, ScaledPrototypeLayer), "Prototypes head is not enabled."

    P = model.head.get_params()
    Wsm = P["class_probs"].cpu().numpy()
    m = int(proto_id)
    primary = int(np.argmax(Wsm[m]))
    primary_name = class_names[primary] if class_names is not None else f"Class {primary}"

    lines = []
    lines.append(f"Prototype #{m}")
    lines.append(f"  Primary class: {primary_name} | probs={np.round(Wsm[m], 3).tolist()} | entropy={float((-Wsm[m]*np.log(Wsm[m]+1e-12)).sum()):.3f}")
    lines.append("  Top rules:")
    agg_last, _ = get_last_block_static_metadata(model, top_k_facts_override=top_k_facts_per_rule)
    fd = render_fact_descriptions(model, feature_names)

    def agg_row_to_str(aw_row):
        if aw_row is None:
            return "AND (fixed)"
        A = int(aw_row.shape[0]) if hasattr(aw_row, "shape") else len(AGG_NAMES)
        names = AGG_NAMES[:A]
        parts = []
        for i, a in enumerate(aw_row.tolist()):
            if i < len(names) and a > 1e-6:
                parts.append(f"{float(a):.2f} {names[i]}")
        return " + ".join(parts) if parts else "∅"

    for (r, w) in prototype_top_rules(model, m, top_k_rules=top_k_rules):
        r = int(r)
        agg_str = "AND (fixed)"
        if agg_last is not None and 0 <= r < agg_last.shape[0]:
            aw = agg_last[r]
            agg_str = agg_row_to_str(aw)

        facts_str = ""
        if trace_to_base:
            base_facts = trace_rule_to_base_facts(model, r, top_k_per_step=top_k_facts_per_rule)
            if base_facts:
                facts_str = " | ".join([f"[F{int(fid)+1}] {fd[int(fid)]}" for fid in base_facts[:top_k_facts_per_rule]])

        lines.append(f"  • R{r+1}: weight={w:+.3f} | {agg_str}")
        if facts_str:
            lines.append(f"      {facts_str}")
    return "\n".join(lines)
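A sketch of the prototype-inspection workflow (illustrative, not part of the diff; it assumes the model was built with a `ScaledPrototypeLayer` head, and `model`, `X_train`, `y_train`, `X_test`, `feature_names` are placeholders):

```python
# Hypothetical usage of the prototype tracing utilities.
from nous.explain.traces import (
    describe_prototype, prototype_contribution_df, prototype_report_global,
)

print(describe_prototype(model, proto_id=0, feature_names=feature_names))
report = prototype_report_global(model, X_train, y=y_train, k_neighbors=5)
contrib = prototype_contribution_df(model, X_test[0], top_k=5)
```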
nous/explain/utils.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import annotations
import torch

def extract_active_masks(internals) -> list[torch.Tensor]:
    """
    Extract per-block active gate masks (0/1) from internals produced by forward_explain.
    """
    masks = []
    keys = sorted([k for k in internals.keys() if k.startswith("block_")], key=lambda s: int(s.split("_")[1]))
    for key in keys:
        gate_mask = internals[key]['gate_mask']
        if isinstance(gate_mask, torch.Tensor):
            gate_mask = gate_mask.squeeze(0)
        masks.append((gate_mask > 0).float())
    return masks
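A minimal sketch of the mask round-trip this helper enables (illustrative; `model` and `X_test` are placeholders):

```python
# Hypothetical usage: freeze the active gate set from one forward_explain
# call and replay it via restrict_masks in another.
from nous.explain.utils import extract_active_masks

_, _, internals = model.forward_explain(X_test[0])
masks = extract_active_masks(internals)
probas, logits, _ = model.forward_explain(X_test[0], restrict_masks=masks)
```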