nous 0.0.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of nous might be problematic. Click here for more details.

nous/__init__.py CHANGED
@@ -1,2 +1,103 @@
1
- # nous/__init__.py
2
- __version__ = "0.0.1"
1
+ from .version import __version__
2
+ from .model import NousNet
3
+ from .facts import BetaFactLayer, PiecewiseLinearCalibrator
4
+ from .prototypes import ScaledPrototypeLayer
5
+ from .rules import FixedPairRuleLayer, SoftmaxRuleLayer, SparseRuleLayer, SimpleNousBlock
6
+
7
+ # Explainability (core API)
8
+ from .explain import (
9
+ rule_impact_df,
10
+ minimal_sufficient_explanation,
11
+ select_pruning_threshold_global,
12
+ select_pruning_threshold_global_bs,
13
+ global_rulebook,
14
+ generate_enhanced_explanation,
15
+ explanation_fidelity_metrics,
16
+ explanation_stability,
17
+ aggregator_mixture_report,
18
+ suggest_rule_counterfactuals,
19
+ render_fact_descriptions,
20
+ AGG_NAMES,
21
+ )
22
+ from .explain.aggregator import format_agg_mixture
23
+
24
+ # Prototype tracing utilities
25
+ from .explain.traces import (
26
+ describe_prototype,
27
+ prototype_report_global,
28
+ prototype_contribution_df,
29
+ prototype_top_rules,
30
+ trace_rule_to_base_facts,
31
+ get_last_block_static_metadata,
32
+ )
33
+
34
+ # Export utilities
35
+ from .export import (
36
+ export_numpy_inference,
37
+ validate_numpy_vs_torch,
38
+ export_and_validate,
39
+ load_numpy_module,
40
+ )
41
+
42
+ # Training and evaluation
43
+ from .training import (
44
+ train_model,
45
+ evaluate_classification,
46
+ evaluate_regression,
47
+ make_sparse_regression_hook,
48
+ )
49
+
50
+ # Dataset helpers (used in examples)
51
+ from .data import get_wine_data, get_california_housing_data
52
+
53
+ # Utilities
54
+ from .utils import set_global_seed
55
+
56
# Public API of the top-level ``nous`` package (order matters for ``import *``
# listings and documentation tools; keep it aligned with the imports above).
__all__ = [
    "__version__",
    # Core model and building blocks
    "NousNet",
    "BetaFactLayer", "PiecewiseLinearCalibrator",
    "ScaledPrototypeLayer",
    "FixedPairRuleLayer", "SoftmaxRuleLayer", "SparseRuleLayer", "SimpleNousBlock",
    # Explainability API
    "rule_impact_df", "minimal_sufficient_explanation",
    "select_pruning_threshold_global", "select_pruning_threshold_global_bs",
    "global_rulebook", "generate_enhanced_explanation",
    "explanation_fidelity_metrics", "explanation_stability",
    "aggregator_mixture_report", "suggest_rule_counterfactuals",
    "render_fact_descriptions", "AGG_NAMES", "format_agg_mixture",
    # Prototype tracing
    "describe_prototype", "prototype_report_global", "prototype_contribution_df",
    "prototype_top_rules", "trace_rule_to_base_facts", "get_last_block_static_metadata",
    # NumPy export
    "export_numpy_inference", "validate_numpy_vs_torch",
    "export_and_validate", "load_numpy_module",
    # Training / evaluation
    "train_model", "evaluate_classification", "evaluate_regression",
    "make_sparse_regression_hook",
    # Example datasets
    "get_wine_data", "get_california_housing_data",
    # Misc utilities
    "set_global_seed",
]
nous/data/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .wine import get_wine_data
2
+ from .california import get_california_housing_data
3
+
4
# Dataset loaders re-exported at the ``nous.data`` package level.
__all__ = [
    "get_wine_data",
    "get_california_housing_data",
]
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ from typing import Tuple, List
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.preprocessing import StandardScaler
6
+ from sklearn.datasets import fetch_california_housing
7
+
8
def get_california_housing_data(
    scale_y: bool = True,
    *,
    test_size: float = 0.2,
    val_size: float = 0.2,
    random_state: int = 42,
):
    """
    Load the California Housing regression dataset and return standardized splits.

    The data is split into train+val / test (``test_size``), and train+val is then
    split into train / val (``val_size`` is a fraction of train+val). The feature
    ``StandardScaler`` is fit on train+val only and applied to the test split.

    Parameters
    ----------
    scale_y : bool
        If True, fit a ``StandardScaler`` on ``y_train`` and apply it to ``y_train``
        and ``y_val``. ``y_test`` is NOT transformed and is returned on the original
        target scale (NOTE(review): predictions made in scaled space presumably need
        ``y_scaler.inverse_transform`` before being compared with ``y_test``).
    test_size, val_size, random_state
        Forwarded to ``train_test_split``; the defaults reproduce the original splits.

    Returns
    -------
    X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names,
    task_type, y_scaler
        ``class_names`` is None, ``task_type`` is ``"regression"``, and ``y_scaler``
        is the fitted target scaler (None when ``scale_y`` is False).
    """
    data = fetch_california_housing()
    X, y = data.data, data.target
    feature_names = data.feature_names

    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # Fit the feature scaler on train+val only so no test statistics leak in.
    x_scaler = StandardScaler()
    X_train_full = x_scaler.fit_transform(X_train_full)
    X_test = x_scaler.transform(X_test)

    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, test_size=val_size, random_state=random_state
    )

    y_scaler = None
    if scale_y:
        # Target scaler is fit on y_train only; y_test is left untouched.
        y_scaler = StandardScaler()
        y_train = y_scaler.fit_transform(y_train.reshape(-1, 1)).ravel()
        y_val = y_scaler.transform(y_val.reshape(-1, 1)).ravel()
    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names, None, "regression", y_scaler
nous/data/wine.py ADDED
@@ -0,0 +1,29 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ from typing import Tuple, List
4
+ from sklearn.model_selection import train_test_split
5
+ from sklearn.preprocessing import StandardScaler, LabelEncoder
6
+
7
def get_wine_data(*, test_size: float = 0.2, val_size: float = 0.2, random_state: int = 42):
    """
    Load the UCI Wine dataset (id=109) via ``ucimlrepo`` and return standardized splits.

    Labels are label-encoded to ``0..K-1``, and class names are ``"Class_1".."Class_K"``.
    Both splits (train+val / test, then train / val) are stratified by label.
    The feature ``StandardScaler`` is fit on train+val only and applied to the test split.

    Parameters
    ----------
    test_size, val_size, random_state
        Forwarded to ``train_test_split`` (``val_size`` is a fraction of train+val);
        the defaults reproduce the original splits.

    Returns
    -------
    X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names,
    task_type, y_scaler
        ``task_type`` is ``"classification"`` and ``y_scaler`` is None.
    """
    # Imported lazily so importing the package does not require ucimlrepo.
    from ucimlrepo import fetch_ucirepo

    wine = fetch_ucirepo(id=109)
    X, y_df = wine.data.features, wine.data.targets
    feature_names = X.columns.tolist()
    y = LabelEncoder().fit_transform(y_df.values.ravel())
    class_names = [f"Class_{i+1}" for i in range(len(np.unique(y)))]

    X_train_full, X_test, y_train_full, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    preprocessor = StandardScaler()
    X_train_full = preprocessor.fit_transform(X_train_full)
    X_test = preprocessor.transform(X_test)
    X_train, X_val, y_train, y_val = train_test_split(
        X_train_full, y_train_full, test_size=val_size, random_state=random_state, stratify=y_train_full
    )
    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names, "classification", None
@@ -0,0 +1,26 @@
1
+ from .aggregator import AGG_NAMES, aggregator_mixture_report, format_agg_mixture
2
+ from .facts_desc import render_fact_descriptions
3
+ from .loo import rule_impact_df
4
+ from .mse import minimal_sufficient_explanation
5
+ from .pruning import select_pruning_threshold_global, select_pruning_threshold_global_bs
6
+ from .global_book import global_rulebook
7
+ from .generate import generate_enhanced_explanation
8
+ from .fidelity import explanation_fidelity_metrics
9
+ from .stability import explanation_stability
10
+ from .cf import suggest_rule_counterfactuals
11
+
12
# Public API of ``nous.explain`` (same order as the imports above).
__all__ = [
    # Aggregator mixtures
    "AGG_NAMES", "aggregator_mixture_report", "format_agg_mixture",
    # Fact rendering
    "render_fact_descriptions",
    # Local rule impacts and minimal explanations
    "rule_impact_df", "minimal_sufficient_explanation",
    # Pruning thresholds
    "select_pruning_threshold_global", "select_pruning_threshold_global_bs",
    # Global, textual and quality reports
    "global_rulebook", "generate_enhanced_explanation",
    "explanation_fidelity_metrics", "explanation_stability",
    # Counterfactuals
    "suggest_rule_counterfactuals",
]
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import pandas as pd
4
+ from typing import List
5
+ from ..model import NousNet
6
+
7
# Display names for aggregator mixture components, indexed by weight position.
AGG_NAMES = ['AND', 'OR', 'k-of-n', 'NOT']


def format_agg_mixture(weights, min_weight: float = 1e-6) -> str:
    """
    Render an aggregator mixture as a compact string, e.g. ``"0.70 AND + 0.30 OR"``.

    Parameters
    ----------
    weights : 1-D indexable sequence of numbers (list, tuple, NumPy array or tensor)
        Mixture weight per aggregator, position-aligned with ``AGG_NAMES``.
        More weights than ``AGG_NAMES`` entries raises ``IndexError``.
    min_weight : float
        Weights less than or equal to this value are omitted (default 1e-6).

    Returns
    -------
    str
        ``" + "``-joined ``"<weight:.2f> <name>"`` terms, or ``"∅"`` when no
        weight exceeds ``min_weight``.
    """
    parts = []
    # len() (rather than .shape[0]) also accepts plain lists and tuples.
    for i in range(len(weights)):
        w = float(weights[i])
        if w > min_weight:
            parts.append(f"{w:.2f} {AGG_NAMES[i]}")
    return " + ".join(parts) if parts else "∅"
16
+
17
def aggregator_mixture_report(model: "NousNet", X, max_samples: int = 1000, device=None) -> pd.DataFrame:
    """
    Summarize aggregator mixture weights over (up to) ``max_samples`` rows of ``X``.

    Each sample is run through ``model.forward_explain``. The ``aggregator_weights``
    of every ``block_*`` entry in the returned internals are then collected; entries
    whose weights are None are skipped.

    Parameters
    ----------
    model : NousNet
        Model exposing ``forward_explain`` and ``parameters``.
    X : indexable of samples
        Only the first ``min(len(X), max_samples)`` rows are used.
    device : optional
        Device forwarded to ``forward_explain``; defaults to the model's parameter device.

    Returns
    -------
    pd.DataFrame
        One row with the mean weight per aggregator (columns named after the
        leading ``AGG_NAMES`` entries), plus ``entropy``: the mean over weight rows
        of ``-sum(w * log w)`` (natural log). If nothing was collected, an empty
        frame with columns ``AGG_NAMES + ["entropy"]`` is returned.
    """
    device = device or next(model.parameters()).device
    n = min(len(X), max_samples)
    acc = []
    for i in range(n):
        _, _, internals = model.forward_explain(X[i], device=device)
        for key, details in internals.items():
            if not key.startswith("block_"):
                continue
            aw = details['aggregator_weights']
            if aw is None:
                continue
            if hasattr(aw, "detach"):
                # Torch tensor: detach so .numpy() works even when autograd tracks it.
                aw = aw.detach().cpu().numpy()
            # Promote a single weight vector to one row so everything stacks along axis 0.
            acc.append(np.atleast_2d(np.asarray(aw)))
    if not acc:
        return pd.DataFrame(columns=[*AGG_NAMES, "entropy"])
    A = np.concatenate(acc, axis=0)
    mean = A.mean(axis=0)
    # Clip the log so exact zeros (log(1e-12)) cannot dominate numerically.
    ent = (-A * np.clip(np.log(A + 1e-12), -50, 50)).sum(axis=1).mean()
    cols = AGG_NAMES[:A.shape[1]]
    return pd.DataFrame([dict(**{c: float(v) for c, v in zip(cols, mean)}, entropy=float(ent))])
nous/explain/cf.py ADDED
@@ -0,0 +1,137 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ from typing import Optional, Sequence, List, Dict, Any
5
+ from ..model import NousNet
6
+ from .loo import rule_impact_df
7
+
8
def suggest_rule_counterfactuals(
    model: "NousNet", x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
    target: str = "flip",  # 'flip' (classification), 'margin_drop', 'reg_delta'
    target_value: Optional[float] = None,  # for margin_drop/reg_delta
    y_scaler=None,
    k_rules: int = 3, fact_target_level: float = 0.1, max_features: int = 2,
    loo_mode: str = 'frozen', top_m_rules: int = 10, use_pre_norm: bool = False,
    alphas: Sequence[float] = (0.5, 1.0, 1.5, 2.0),
    device=None
) -> List[Dict[str, Any]]:
    """
    Suggest counterfactual input deltas guided by influential rules using β-fact geometry.

    For each of the top ``k_rules`` rules, ranked by leave-one-out impact from
    ``rule_impact_df`` (largest first):

    1. Take up to two of the rule's facts, highest current activation first.
    2. Invert the β activation to find the ``diff`` value that would move each fact
       to ``fact_target_level``.
    3. Compute the minimal-L2 shift of the calibrated input along the fact's
       ``net_w`` row, restricted to its ``max_features`` largest-|w| features.
    4. If calibrators exist, map the shift back to raw inputs by dividing by each
       calibrator's ``local_slope``.
    5. Accumulate the per-feature deltas and verify them by re-running
       ``forward_explain`` with the deltas scaled by each value in ``alphas``.

    Parameters
    ----------
    target : {'flip', 'margin_drop', 'reg_delta'}
        Success criterion:
        - 'flip': the predicted class changes.
        - 'margin_drop': the margin (predicted minus runner-up logit) drops by at
          least ``target_value``.
        - 'reg_delta': the regression output satisfies
          ``new_pred - base_pred <= target_value``.
    y_scaler
        Accepted for API symmetry; not used by this function.

    Returns
    -------
    list of dict
        One entry per rule for which at least one finite delta was found, with keys:
        - ``rule``: 1-based ``(block, rule)``.
        - ``facts``: ``"F<id>"`` labels.
        - ``deltas``: ``(feature_name, delta)`` pairs sorted by ``|delta|`` descending.
        - ``verified``: bool.
        - ``new_out``: the outputs at the first successful alpha, or None.
    """
    device = device or next(model.parameters()).device
    task = model.config['task_type']

    base_probas, base_logits, base_internals = model.forward_explain(x_sample, device=device)
    if task == "classification":
        pred_idx = int(np.argmax(base_probas))
        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
        base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
    else:
        base_pred = float(base_logits[0])

    # Forward the resolved device, as global_rulebook does, so LOO runs where the model lives.
    imp = rule_impact_df(
        model, x_sample, feature_names, class_names=class_names,
        loo_mode=loo_mode, top_m_rules=top_m_rules, use_pre_norm=use_pre_norm,
        device=device,
    )
    if imp.empty:
        return []

    if task == "classification":
        margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]
        imp = imp.sort_values(by=margin_col, ascending=False)
    else:
        imp = imp.sort_values(by="Δprediction", ascending=False)
    imp = imp.head(k_rules)

    x = torch.tensor(x_sample, dtype=torch.float32, device=device).unsqueeze(0)
    if model.calibrators is not None:
        x_cal = torch.stack([calib(x[:, i]) for i, calib in enumerate(model.calibrators)], dim=1)
    else:
        x_cal = x
    diff, k_vec, nu_vec, net_w = model.fact.compute_diff_and_params(x_cal)  # [1,F], [F], [F], [F,D]
    facts_act = model.fact(x_cal).squeeze(0)

    suggestions = []
    for _, row in imp.iterrows():
        # rule_impact_df reports 1-based block/rule ids.
        b = int(row["block"]) - 1
        r = int(row["rule"]) - 1
        details = base_internals[f'block_{b}']
        facts_used = details.get("facts_used", None)
        if isinstance(facts_used, torch.Tensor):
            facts_used = facts_used.cpu().numpy()
        # len() works for both arrays and plain lists.
        if facts_used is None or len(facts_used) <= r:
            continue
        used = facts_used[r]
        used = [int(used)] if np.ndim(used) == 0 else [int(u) for u in np.asarray(used).tolist()]

        # Act on (at most) the two most active facts of this rule.
        used_sorted = sorted(used, key=lambda fid: float(facts_act[fid].item()), reverse=True)[:max(1, min(len(used), 2))]

        deltas: Dict[int, float] = {}
        for fid in used_sorted:
            kf = float(k_vec[fid].item())
            nuf = float(nu_vec[fid].item())
            diff_now = float(diff[0, fid].item())
            w = net_w[fid].detach().clone()  # [D]

            y_target = float(fact_target_level)
            # Invert β: diff_target = logit(y_target^(1/nu)) / k
            diff_target = float(torch.logit(torch.tensor(y_target, device=device).pow(1.0 / max(nuf, 1e-6))))
            diff_target = diff_target / max(kf, 1e-6)
            delta_diff = diff_target - diff_now

            # Minimal-L2 shift in calibrated space, restricted to the top-|w| features.
            w_np = w.cpu().numpy()
            idxs = np.argsort(-np.abs(w_np))[:max_features]
            w_sel = torch.zeros_like(w)
            w_sel[idxs] = w[idxs]
            denom = float(w_sel.pow(2).sum().item())
            if denom < 1e-12:
                continue
            delta_x_cal = (delta_diff / denom) * w_sel  # minimal L2 shift in x̃

            # Map the calibrated shift back to raw inputs via the local calibrator slope.
            delta_x = delta_x_cal.clone()
            if model.calibrators is not None:
                for i in idxs:
                    xi = x[0, i]
                    slope_i = model.calibrators[i].local_slope(xi)
                    delta_x[i] = delta_x_cal[i] / slope_i

            for i in idxs:
                d = float(delta_x[i].item())
                # Non-finite deltas come from a zero calibrator slope or from a
                # fact_target_level of 0/1 (infinite logit). Skip them: they would
                # otherwise turn the verification input into inf/NaN.
                if not np.isfinite(d):
                    continue
                deltas[i] = deltas.get(i, 0.0) + d

        if not deltas:
            continue

        feat_deltas = sorted([(feature_names[i], d) for i, d in deltas.items()], key=lambda t: -abs(t[1]))
        success = False
        new_out = None
        for a in alphas:
            x_try = x.clone()
            for i, d in deltas.items():
                x_try[0, i] = x_try[0, i] + a * d
            prob2, logit2, _ = model.forward_explain(x_try.squeeze(0).cpu().numpy(), device=device)
            if task == "classification":
                new_pred = int(np.argmax(prob2))
                new_margin = float(logit2[pred_idx] - logit2[runner_up])
                if target == "flip" and new_pred != pred_idx:
                    success, new_out = True, {"pred": new_pred, "margin": new_margin}
                    break
                if target == "margin_drop" and target_value is not None and new_margin <= base_margin - float(target_value):
                    success, new_out = True, {"pred": new_pred, "margin": new_margin}
                    break
            else:
                new_pred = float(logit2[0])
                if target == "reg_delta" and target_value is not None:
                    if (new_pred - base_pred) <= float(target_value):
                        success, new_out = True, {"pred": new_pred}
                        break

        suggestions.append({
            "rule": (b + 1, r + 1),
            "facts": [f"F{fid+1}" for fid in used_sorted],
            "deltas": feat_deltas,
            "verified": success,
            "new_out": new_out
        })
    return suggestions
@@ -0,0 +1,23 @@
1
+ from __future__ import annotations
2
+ from typing import Dict, Sequence
3
+ from ..model import NousNet
4
+
5
def render_fact_descriptions(model: "NousNet", feature_names: Sequence[str], top_k_feats: int = 4, eps: float = 0.03) -> Dict[int, str]:
    """
    Build one human-readable string per base β-fact from its (L - R) weight vector.

    For each fact, features whose net weight exceeds ``eps`` go on the positive side,
    and those below ``-eps`` go on the negative side (shown as magnitudes). Each side
    keeps at most ``top_k_feats`` terms, strongest first, and renders as ``"0"`` when
    empty. The fact's threshold, ``k`` and ``ν`` parameters are appended. When the
    model has calibrators, a note says the inputs are calibrated features.

    Returns
    -------
    dict
        Maps fact index (0-based) to its description string.
    """
    left, right, thresholds, sharpness, shape = model.fact.get_rule_parameters()
    suffix = " where x̃ are calibrated features" if model.calibrators is not None else ""

    def _side(terms):
        # Strongest |weight| first (stable for ties), capped at top_k_feats.
        ranked = sorted(terms, key=lambda t: abs(t[1]), reverse=True)[:top_k_feats]
        if not ranked:
            return "0"
        return " + ".join(f"{w:.2f}·{name}" for name, w in ranked)

    descriptions = {}
    for fid in range(left.shape[0]):
        net = left[fid] - right[fid]
        positive, negative = [], []
        for j in range(len(net)):
            if net[j] > eps:
                positive.append((feature_names[j], net[j]))
            if net[j] < -eps:
                negative.append((feature_names[j], -net[j]))
        descriptions[fid] = (
            f"β( [L−R](x̃) = ({_side(positive)}) − ({_side(negative)}) "
            f"> {thresholds[fid]:.2f}; k={sharpness[fid]:.2f}, ν={shape[fid]:.2f} )"
            + suffix
        )
    return descriptions
@@ -0,0 +1,56 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import torch
4
+ from typing import Optional, Sequence, Dict, Any
5
+ from ..model import NousNet
6
+ from .mse import minimal_sufficient_explanation
7
+
8
def explanation_fidelity_metrics(
    model: "NousNet", x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
    margin_tolerance: float = 0.0, prob_tolerance: float = 0.0, pred_tolerance_reg: float = 0.02,
    device=None
) -> Dict[str, float]:
    """
    Sufficiency and comprehensiveness metrics using minimal-sufficient-explanation masks.

    The kept rule masks come from ``minimal_sufficient_explanation``, called with
    ``freeze_non_active=True``. The model is then re-run twice:

    - sufficiency: ``restrict_masks`` set to the kept masks (only the explanation).
    - comprehensiveness: ``restrict_masks`` set to their complement (explanation removed).

    Returns
    -------
    dict
        Classification keys: ``base_margin``, ``sufficiency_margin``,
        ``comprehensiveness_margin`` and ``kept_size``. A margin is the logit of the
        base prediction minus the logit of the base runner-up class.
        Regression keys: ``base_pred``, ``sufficiency_pred``,
        ``comprehensiveness_pred`` and ``kept_size``.
        ``kept_size`` is the total number of kept entries across all masks.
    """
    device = device or next(model.parameters()).device
    task = model.config['task_type']

    base_probas, base_logits, _ = model.forward_explain(x_sample, device=device)
    if task == "classification":
        pred_idx = int(np.argmax(base_probas))
        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
        base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
    else:
        base_pred = float(base_logits[0])

    mse = minimal_sufficient_explanation(
        model, x_sample, feature_names, class_names=class_names,
        margin_tolerance=margin_tolerance, prob_tolerance=prob_tolerance,
        pred_tolerance_reg=pred_tolerance_reg, freeze_non_active=True, device=device
    )
    kept_masks = mse["kept_masks"]
    kept_size = float(sum(int(m.sum().item()) for m in kept_masks))

    prob_s, logit_s, _ = model.forward_explain(x_sample, restrict_masks=kept_masks, device=device)
    # Complement masks: boolean tensors do not support subtraction, so invert them with ~.
    inv_masks = [(~m if m.dtype == torch.bool else torch.ones_like(m) - m) for m in kept_masks]
    prob_c, logit_c, _ = model.forward_explain(x_sample, restrict_masks=inv_masks, device=device)

    if task == "classification":
        # Margins are measured against the *base* prediction / runner-up classes.
        return {
            "base_margin": base_margin,
            "sufficiency_margin": float(logit_s[pred_idx] - logit_s[runner_up]),
            "comprehensiveness_margin": float(logit_c[pred_idx] - logit_c[runner_up]),
            "kept_size": kept_size,
        }
    return {
        "base_pred": base_pred,
        "sufficiency_pred": float(logit_s[0]),
        "comprehensiveness_pred": float(logit_c[0]),
        "kept_size": kept_size,
    }
@@ -0,0 +1,86 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ from typing import Optional, Sequence
4
+ from ..model import NousNet
5
+ from .loo import rule_impact_df
6
+
7
def generate_enhanced_explanation(
    model: "NousNet", x_sample, y_true, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None, y_scaler=None,
    loo_mode: str = 'replacement', top_m_rules: Optional[int] = None, use_pre_norm: bool = False,
    use_pruning: bool = False, pruning_threshold: float = 0.0, vs_class_idx: Optional[int] = None,
    top_k: int = 5,
) -> str:
    """
    Build human-readable explanation text with the top rule impacts for a single sample.

    The text contains a header (model and task, the prediction, confidence for
    classification, ground truth, and the pruning setting when enabled), followed by
    the ``top_k`` rows of ``rule_impact_df``.

    Parameters
    ----------
    y_true
        Class index (classification) or target value (regression). For regression
        with ``y_scaler``, ``y_true`` and the model output are mapped back through
        ``y_scaler.inverse_transform``. Rule deltas are multiplied by
        ``y_scaler.scale_[0]``.
    class_names : optional
        Labels indexed by class id. When missing or empty, ``"Class <idx>"`` is used
        for the predicted, true and contrast classes alike.
    vs_class_idx : optional
        Adds a contrastive "why A vs B" line (classification only) and is forwarded
        to ``rule_impact_df`` as ``vs_class``.
    top_k : int
        Number of rule rows listed (default 5).

    Returns
    -------
    str
        Newline-joined explanation text.
    """
    task = model.config['task_type']
    base_probas, base_logits, _ = model.forward_explain(
        x_sample, apply_pruning=use_pruning, pruning_threshold=pruning_threshold
    )

    def _class_label(idx):
        # One fallback for every class reference, so labels stay consistent.
        return class_names[idx] if class_names else f"Class {idx}"

    if task == "classification":
        pred_idx = int(np.argmax(base_probas))
        pred_name = _class_label(pred_idx)
        conf = float(base_probas[pred_idx])
        true_name = _class_label(y_true)
    else:
        # Target scaling only applies to regression outputs.
        if y_scaler is not None:
            y_true_unscaled = y_scaler.inverse_transform(np.array(y_true).reshape(-1, 1)).item()
            base_pred_unscaled = y_scaler.inverse_transform(base_logits.reshape(-1, 1)).item()
        else:
            y_true_unscaled = y_true
            base_pred_unscaled = base_logits[0]
        pred_name = f"Value: {base_pred_unscaled:.3f}"
        true_name = f"{y_true_unscaled:.3f}"

    model_tag = model.config['rule_selection_method'].upper()
    lines = [
        f"MODEL: {model_tag} rules | TASK: {task.upper()}",
        f"SAMPLE PREDICTION: {pred_name}",
    ]
    if task == "classification":
        lines.append(f" - Confidence: {conf:.3f}")
    lines.append(f" - Ground Truth: {true_name}")
    if use_pruning:
        lines.append(f" - Pruning: |act| >= {pruning_threshold:.4f} (forward uses pruned activations)")
    lines.append("-" * 60)

    imp = rule_impact_df(
        model, x_sample, feature_names, class_names=class_names, vs_class=vs_class_idx,
        loo_mode=loo_mode, top_m_rules=top_m_rules, use_pre_norm=use_pre_norm,
        use_pruning=use_pruning, pruning_threshold=pruning_threshold
    )
    if imp.empty:
        lines.append("No active rules above threshold.")
        return "\n".join(lines)

    if vs_class_idx is not None and task == "classification":
        lines.append(f"(Contrastive) Why '{pred_name}' vs '{_class_label(vs_class_idx)}'?")
    lines.append(f"CAUSAL RULE IMPACT (Top {top_k}):")

    margin_col = None
    if task == "classification":
        margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]

    for _, row in imp.head(top_k).iterrows():
        b, r = row['block'], row['rule']
        aggs, facts = row['aggregators'], row['facts']
        badge = ""
        if 'supports_pred' in row:
            badge = " [+]" if row['supports_pred'] else " [-]"
        if 'critical_flip' in row and row['critical_flip']:
            badge += " [CRITICAL]"

        if task == "classification":
            s = f"Δmargin={row[margin_col]:+.3f}{badge}"
        else:
            delta_pred = row['Δprediction']
            if y_scaler is not None:
                # A delta in scaled units maps back by multiplying by the scaler's scale_.
                delta_pred = delta_pred * y_scaler.scale_[0]
            s = f"Δprediction={delta_pred:+.3f}{badge}"
        lines.append(f" • B{b}/R{r}: {s} | {aggs}\n {facts}")

    if len(imp) > top_k:
        lines.append(f" ... and {len(imp) - top_k} more active rules.")

    return "\n".join(lines)
@@ -0,0 +1,52 @@
1
+ from __future__ import annotations
2
+ import numpy as np
3
+ import pandas as pd
4
+ from typing import Optional, Sequence
5
+ from ..model import NousNet
6
+ from .loo import rule_impact_df
7
+
8
def global_rulebook(
    model: "NousNet", X, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
    pruning_threshold: Optional[float] = None, use_pruning: bool = False, allow_replacement: bool = True, freeze_non_active: bool = False,
    max_samples: int = 1000, device=None
) -> pd.DataFrame:
    """
    Aggregate per-sample leave-one-out rule impacts into a global rulebook.

    ``rule_impact_df`` is run on the first ``min(len(X), max_samples)`` samples.
    Rules are grouped by ``(block, rule, aggregators)``.

    The LOO mode is ``'frozen'`` only when ``allow_replacement`` is False AND
    ``freeze_non_active`` is True; every other combination uses ``'replacement'``.

    Returns
    -------
    pd.DataFrame
        One row per rule, sorted by ``mean_abs_impact`` descending, with columns:
        - ``block``, ``rule``, ``aggregators``.
        - ``activation_freq``: fraction of analysed samples whose impact table
          listed the rule.
        - ``mean_abs_impact``: mean ``|Δmargin(...)|``, or ``|Δprediction|`` when
          there is no margin column.
        - ``critical_rate``: fraction of appearances flagged ``critical_flip``.
        When no rule is reported, an empty frame with these columns is returned.
    """
    try:
        from tqdm.auto import tqdm  # optional dependency (extras)
    except ImportError:
        def tqdm(iterable, **_kwargs):
            # No progress bar available: iterate plainly.
            return iterable

    columns = ["block", "rule", "aggregators", "activation_freq", "mean_abs_impact", "critical_rate"]
    device = device or next(model.parameters()).device
    n = min(len(X), max_samples)
    totals = {}

    loo_mode = 'frozen' if (not allow_replacement and freeze_non_active) else 'replacement'

    for i in tqdm(range(n), desc="Analyzing samples"):
        df = rule_impact_df(
            model, X[i], feature_names, class_names=class_names,
            loo_mode=loo_mode, use_pruning=use_pruning,
            pruning_threshold=(pruning_threshold or 0.0), device=device,
        )
        if df.empty:
            continue
        # Same column set for every row of this sample: resolve the metric once.
        metric_col = [c for c in df.columns if c.startswith("Δmargin(")]
        for _, row in df.iterrows():
            key = (int(row["block"]), int(row["rule"]), row["aggregators"])
            d = totals.setdefault(key, {"count": 0, "sum_abs_impact": 0.0, "critical": 0})
            d["count"] += 1
            if "critical_flip" in row and row["critical_flip"]:
                d["critical"] += 1
            val = abs(float(row[metric_col[0]])) if metric_col else abs(float(row.get("Δprediction", 0.0)))
            d["sum_abs_impact"] += val

    rows = []
    for (b, r, agg), v in totals.items():
        count = v["count"]
        rows.append({
            "block": b, "rule": r, "aggregators": agg,
            "activation_freq": count / n,
            "mean_abs_impact": v["sum_abs_impact"] / count if count else 0.0,
            "critical_rate": v["critical"] / count if count else 0.0
        })
    if not rows:
        # sort_values on a column-less DataFrame would raise KeyError.
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows, columns=columns).sort_values("mean_abs_impact", ascending=False).reset_index(drop=True)
+ return pd.DataFrame(rows).sort_values("mean_abs_impact", ascending=False).reset_index(drop=True)