nous-0.1.0-py3-none-any.whl → nous-0.2.0-py3-none-any.whl
- nous/__init__.py +96 -19
- nous/data/__init__.py +4 -0
- nous/data/california.py +32 -0
- nous/data/wine.py +29 -0
- nous/explain/__init__.py +26 -0
- nous/explain/aggregator.py +34 -0
- nous/explain/cf.py +137 -0
- nous/explain/facts_desc.py +23 -0
- nous/explain/fidelity.py +56 -0
- nous/explain/generate.py +86 -0
- nous/explain/global_book.py +52 -0
- nous/explain/loo.py +130 -0
- nous/explain/mse.py +93 -0
- nous/explain/pruning.py +117 -0
- nous/explain/stability.py +42 -0
- nous/explain/traces.py +285 -0
- nous/explain/utils.py +15 -0
- nous/export/__init__.py +13 -0
- nous/export/numpy_infer.py +412 -0
- nous/facts.py +112 -0
- nous/model.py +226 -0
- nous/prototypes.py +43 -0
- nous/rules/__init__.py +11 -0
- nous/rules/blocks.py +63 -0
- nous/rules/fixed.py +26 -0
- nous/rules/softmax.py +93 -0
- nous/rules/sparse.py +142 -0
- nous/training/__init__.py +5 -0
- nous/training/evaluation.py +57 -0
- nous/training/schedulers.py +34 -0
- nous/training/train.py +177 -0
- nous/types.py +4 -0
- nous/utils/__init__.py +3 -0
- nous/utils/metrics.py +2 -0
- nous/utils/seed.py +13 -0
- nous/version.py +1 -0
- nous-0.2.0.dist-info/METADATA +150 -0
- nous-0.2.0.dist-info/RECORD +41 -0
- nous/causal.py +0 -63
- nous/interpret.py +0 -111
- nous/layers.py +0 -117
- nous/models.py +0 -65
- nous-0.1.0.dist-info/METADATA +0 -138
- nous-0.1.0.dist-info/RECORD +0 -10
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/__init__.py
CHANGED
@@ -1,26 +1,103 @@
-
+from .version import __version__
+from .model import NousNet
+from .facts import BetaFactLayer, PiecewiseLinearCalibrator
+from .prototypes import ScaledPrototypeLayer
+from .rules import FixedPairRuleLayer, SoftmaxRuleLayer, SparseRuleLayer, SimpleNousBlock
 
-
+# Explainability (core API)
+from .explain import (
+    rule_impact_df,
+    minimal_sufficient_explanation,
+    select_pruning_threshold_global,
+    select_pruning_threshold_global_bs,
+    global_rulebook,
+    generate_enhanced_explanation,
+    explanation_fidelity_metrics,
+    explanation_stability,
+    aggregator_mixture_report,
+    suggest_rule_counterfactuals,
+    render_fact_descriptions,
+    AGG_NAMES,
+)
+from .explain.aggregator import format_agg_mixture
+
+# Prototype tracing utilities
+from .explain.traces import (
+    describe_prototype,
+    prototype_report_global,
+    prototype_contribution_df,
+    prototype_top_rules,
+    trace_rule_to_base_facts,
+    get_last_block_static_metadata,
+)
 
-
-from .
-
-
-
-
-plot_final_layer_contributions,
+# Export utilities
+from .export import (
+    export_numpy_inference,
+    validate_numpy_vs_torch,
+    export_and_validate,
+    load_numpy_module,
 )
-
+
+# Training and evaluation
+from .training import (
+    train_model,
+    evaluate_classification,
+    evaluate_regression,
+    make_sparse_regression_hook,
+)
+
+# Dataset helpers (used in examples)
+from .data import get_wine_data, get_california_housing_data
+
+# Utilities
+from .utils import set_global_seed
 
 __all__ = [
-
+    "__version__",
+    # Core model and components
     "NousNet",
-
-"
-"
-"
-"
-"
-
-
+    "BetaFactLayer",
+    "PiecewiseLinearCalibrator",
+    "ScaledPrototypeLayer",
+    "FixedPairRuleLayer",
+    "SoftmaxRuleLayer",
+    "SparseRuleLayer",
+    "SimpleNousBlock",
+    # Explainability (core)
+    "rule_impact_df",
+    "minimal_sufficient_explanation",
+    "select_pruning_threshold_global",
+    "select_pruning_threshold_global_bs",
+    "global_rulebook",
+    "generate_enhanced_explanation",
+    "explanation_fidelity_metrics",
+    "explanation_stability",
+    "aggregator_mixture_report",
+    "suggest_rule_counterfactuals",
+    "render_fact_descriptions",
+    "AGG_NAMES",
+    "format_agg_mixture",
+    # Prototype tracing utilities
+    "describe_prototype",
+    "prototype_report_global",
+    "prototype_contribution_df",
+    "prototype_top_rules",
+    "trace_rule_to_base_facts",
+    "get_last_block_static_metadata",
+    # Export utilities
+    "export_numpy_inference",
+    "validate_numpy_vs_torch",
+    "export_and_validate",
+    "load_numpy_module",
+    # Training and evaluation
+    "train_model",
+    "evaluate_classification",
+    "evaluate_regression",
+    "make_sparse_regression_hook",
+    # Dataset helpers
+    "get_wine_data",
+    "get_california_housing_data",
+    # Utilities
+    "set_global_seed",
 ]
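A quick smoke test of the expanded 0.2.0 namespace — a minimal sketch, assuming nous 0.2.0 is installed; every name below is re-exported by the new __init__.py:

from nous import (
    __version__,
    NousNet,          # core model class
    train_model,
    get_wine_data,
    set_global_seed,
)

set_global_seed(0)    # seeding utility added in 0.2.0
print(__version__)    # expected: "0.2.0"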
nous/data/__init__.py
ADDED
nous/data/california.py
ADDED
@@ -0,0 +1,32 @@
+from __future__ import annotations
+import numpy as np
+from typing import Tuple, List
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.datasets import fetch_california_housing
+
+def get_california_housing_data(scale_y: bool = True):
+    """
+    Load California Housing and return standardized X and (optionally) standardized y.
+    Returns
+    -------
+    X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names, task_type, y_scaler
+    """
+    data = fetch_california_housing()
+    X, y = data.data, data.target
+    feature_names = data.feature_names
+
+    X_train_full, X_test, y_train_full, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
+
+    x_scaler = StandardScaler()
+    X_train_full = x_scaler.fit_transform(X_train_full)
+    X_test = x_scaler.transform(X_test)
+
+    X_train, X_val, y_train, y_val = train_test_split(X_train_full, y_train_full, test_size=0.2, random_state=42)
+
+    y_scaler = None
+    if scale_y:
+        y_scaler = StandardScaler()
+        y_train = y_scaler.fit_transform(y_train.reshape(-1,1)).ravel()
+        y_val = y_scaler.transform(y_val.reshape(-1,1)).ravel()
+    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names, None, "regression", y_scaler
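Usage sketch for the loader above. Note that y_test is returned in original units even when scale_y=True, so scaled model outputs must be mapped back through y_scaler before computing test metrics:

from nous.data import get_california_housing_data

(X_train, X_val, X_test, y_train, y_val, y_test,
 feature_names, class_names, task_type, y_scaler) = get_california_housing_data()
print(task_type, X_train.shape)  # "regression", roughly (13209, 8)

# Mapping scaled predictions back to original units (y_pred is hypothetical):
# y_pred_units = y_scaler.inverse_transform(y_pred.reshape(-1, 1)).ravel()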
nous/data/wine.py
ADDED
@@ -0,0 +1,29 @@
+from __future__ import annotations
+import numpy as np
+from typing import Tuple, List
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler, LabelEncoder
+
+def get_wine_data():
+    """
+    Load Wine dataset via ucimlrepo and return standardized splits.
+    Returns
+    -------
+    X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names, task_type, y_scaler
+    """
+    from ucimlrepo import fetch_ucirepo
+    wine = fetch_ucirepo(id=109)
+    X, y_df = wine.data.features, wine.data.targets
+    feature_names = X.columns.tolist()
+    y = LabelEncoder().fit_transform(y_df.values.ravel())
+    class_names = [f"Class_{i+1}" for i in range(len(np.unique(y)))]
+    X_train_full, X_test, y_train_full, y_test = train_test_split(
+        X, y, test_size=0.2, random_state=42, stratify=y
+    )
+    preprocessor = StandardScaler()
+    X_train_full = preprocessor.fit_transform(X_train_full)
+    X_test = preprocessor.transform(X_test)
+    X_train, X_val, y_train, y_val = train_test_split(
+        X_train_full, y_train_full, test_size=0.2, random_state=42, stratify=y_train_full
+    )
+    return X_train, X_val, X_test, y_train, y_val, y_test, feature_names, class_names, "classification", None
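The classification counterpart; ucimlrepo is imported lazily inside the function body, so the optional dependency is only needed at call time:

from nous.data import get_wine_data

(X_train, X_val, X_test, y_train, y_val, y_test,
 feature_names, class_names, task_type, _) = get_wine_data()
print(task_type, class_names)  # "classification", ['Class_1', 'Class_2', 'Class_3']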
nous/explain/__init__.py
ADDED
@@ -0,0 +1,26 @@
+from .aggregator import AGG_NAMES, aggregator_mixture_report, format_agg_mixture
+from .facts_desc import render_fact_descriptions
+from .loo import rule_impact_df
+from .mse import minimal_sufficient_explanation
+from .pruning import select_pruning_threshold_global, select_pruning_threshold_global_bs
+from .global_book import global_rulebook
+from .generate import generate_enhanced_explanation
+from .fidelity import explanation_fidelity_metrics
+from .stability import explanation_stability
+from .cf import suggest_rule_counterfactuals
+
+__all__ = [
+    "AGG_NAMES",
+    "aggregator_mixture_report",
+    "format_agg_mixture",
+    "render_fact_descriptions",
+    "rule_impact_df",
+    "minimal_sufficient_explanation",
+    "select_pruning_threshold_global",
+    "select_pruning_threshold_global_bs",
+    "global_rulebook",
+    "generate_enhanced_explanation",
+    "explanation_fidelity_metrics",
+    "explanation_stability",
+    "suggest_rule_counterfactuals",
+]
nous/explain/aggregator.py
ADDED
@@ -0,0 +1,34 @@
+from __future__ import annotations
+import numpy as np
+import pandas as pd
+from typing import List
+from ..model import NousNet
+
+AGG_NAMES = ['AND', 'OR', 'k-of-n', 'NOT']
+
+def format_agg_mixture(weights) -> str:
+    parts = []
+    for i in range(weights.shape[0]):
+        w = float(weights[i])
+        if w > 1e-6:
+            parts.append(f"{w:.2f} {AGG_NAMES[i]}")
+    return " + ".join(parts) if parts else "∅"
+
+def aggregator_mixture_report(model: NousNet, X, max_samples: int = 1000, device=None) -> pd.DataFrame:
+    device = device or next(model.parameters()).device
+    n = min(len(X), max_samples)
+    acc = []
+    for i in range(n):
+        _, _, internals = model.forward_explain(X[i], device=device)
+        for key in [k for k in internals.keys() if k.startswith("block_")]:
+            aw = internals[key]['aggregator_weights']
+            if aw is None:
+                continue
+            acc.append(aw.cpu().numpy())
+    if not acc:
+        return pd.DataFrame(columns=["AND", "OR", "k-of-n", "NOT", "entropy"])
+    A = np.concatenate(acc, axis=0)
+    mean = A.mean(axis=0)
+    ent = (-A * np.clip(np.log(A + 1e-12), -50, 50)).sum(axis=1).mean()
+    cols = AGG_NAMES[:A.shape[1]]
+    return pd.DataFrame([dict(**{c: float(v) for c, v in zip(cols, mean)}, entropy=float(ent))])
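format_agg_mixture needs no model and is easy to sanity-check on its own; a minimal sketch:

import numpy as np
from nous.explain.aggregator import format_agg_mixture

print(format_agg_mixture(np.array([0.72, 0.28, 0.0, 0.0])))  # "0.72 AND + 0.28 OR"
print(format_agg_mixture(np.zeros(4)))                       # "∅" (no weight above 1e-6)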
nous/explain/cf.py
ADDED
@@ -0,0 +1,137 @@
+from __future__ import annotations
+import numpy as np
+import torch
+from typing import Optional, Sequence, List, Dict, Any
+from ..model import NousNet
+from .loo import rule_impact_df
+
+def suggest_rule_counterfactuals(
+    model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+    target: str = "flip",  # 'flip' (classification), 'margin_drop', 'reg_delta'
+    target_value: Optional[float] = None,  # for margin_drop/reg_delta
+    y_scaler=None,
+    k_rules: int = 3, fact_target_level: float = 0.1, max_features: int = 2,
+    loo_mode: str = 'frozen', top_m_rules: int = 10, use_pre_norm: bool = False,
+    alphas: Sequence[float] = (0.5, 1.0, 1.5, 2.0),
+    device=None
+) -> List[Dict[str, Any]]:
+    """
+    Suggest counterfactual input deltas guided by influential rules using β-fact geometry.
+    Verifies suggested deltas by forward_explain.
+    """
+    device = device or next(model.parameters()).device
+    task = model.config['task_type']
+
+    base_probas, base_logits, base_internals = model.forward_explain(x_sample, device=device)
+    if task == "classification":
+        pred_idx = int(np.argmax(base_probas))
+        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
+        base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
+    else:
+        base_pred = float(base_logits[0])
+
+    imp = rule_impact_df(
+        model, x_sample, feature_names, class_names=class_names,
+        loo_mode=loo_mode, top_m_rules=top_m_rules, use_pre_norm=use_pre_norm
+    )
+    if imp.empty:
+        return []
+
+    if task == "classification":
+        margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]
+        imp = imp.sort_values(by=margin_col, ascending=False)
+    else:
+        imp = imp.sort_values(by="Δprediction", ascending=False)
+    imp = imp.head(k_rules)
+
+    x = torch.tensor(x_sample, dtype=torch.float32, device=device).unsqueeze(0)
+    if model.calibrators is not None:
+        x_cal = torch.stack([calib(x[:, i]) for i, calib in enumerate(model.calibrators)], dim=1)
+    else:
+        x_cal = x
+    diff, k_vec, nu_vec, net_w = model.fact.compute_diff_and_params(x_cal)  # [1,F], [F], [F], [F,D]
+    facts_act = model.fact(x_cal).squeeze(0)
+
+    suggestions = []
+    for _, row in imp.iterrows():
+        b = int(row["block"]) - 1
+        r = int(row["rule"]) - 1
+        details = base_internals[f'block_{b}']
+        facts_used = details.get("facts_used", None)
+        if isinstance(facts_used, torch.Tensor):
+            facts_used = facts_used.cpu().numpy()
+        if facts_used is None or facts_used.shape[0] <= r:
+            continue
+        used = facts_used[r]
+        used = [int(used)] if np.ndim(used) == 0 else [int(u) for u in used.tolist()]
+
+        used_sorted = sorted(used, key=lambda fid: float(facts_act[fid].item()), reverse=True)[:max(1, min(len(used), 2))]
+
+        deltas: Dict[int, float] = {}
+        for fid in used_sorted:
+            y_now = float(facts_act[fid].item()) + 1e-12
+            kf = float(k_vec[fid].item())
+            nuf = float(nu_vec[fid].item())
+            diff_now = float(diff[0, fid].item())
+            w = net_w[fid].detach().clone()  # [D]
+
+            y_target = float(fact_target_level)
+            # Invert β: diff_target = (logit(y_target^(1/nu))) / k
+            diff_target = float(torch.logit(torch.tensor(y_target, device=device).pow(1.0/max(nuf,1e-6))))
+            diff_target = diff_target / max(kf, 1e-6)
+            delta_diff = diff_target - diff_now
+
+            w_np = w.cpu().numpy()
+            idxs = np.argsort(-np.abs(w_np))[:max_features]
+            w_sel = torch.zeros_like(w)
+            w_sel[idxs] = w[idxs]
+            denom = float(w_sel.pow(2).sum().item())
+            if denom < 1e-12:
+                continue
+            delta_x_cal = (delta_diff / denom) * w_sel  # minimal L2 shift in x̃
+
+            delta_x = delta_x_cal.clone()
+            if model.calibrators is not None:
+                for i in idxs:
+                    xi = x[0, i]
+                    slope_i = model.calibrators[i].local_slope(xi)
+                    delta_x[i] = delta_x_cal[i] / slope_i
+
+            for i in idxs:
+                deltas[i] = deltas.get(i, 0.0) + float(delta_x[i].item())
+
+        if not deltas:
+            continue
+
+        feat_deltas = sorted([(feature_names[i], d) for i, d in deltas.items()], key=lambda t: -abs(t[1]))
+        success = False
+        new_out = None
+        for a in alphas:
+            x_try = x.clone()
+            for i, d in deltas.items():
+                x_try[0, i] = x_try[0, i] + a * d
+            prob2, logit2, _ = model.forward_explain(x_try.squeeze(0).cpu().numpy(), device=device)
+            if task == "classification":
+                new_pred = int(np.argmax(prob2))
+                new_margin = float(logit2[pred_idx] - logit2[runner_up])
+                if target == "flip" and new_pred != pred_idx:
+                    success, new_out = True, {"pred": new_pred, "margin": new_margin}
+                    break
+                if target == "margin_drop" and target_value is not None and new_margin <= base_margin - float(target_value):
+                    success, new_out = True, {"pred": new_pred, "margin": new_margin}
+                    break
+            else:
+                new_pred = float(logit2[0])
+                if target == "reg_delta" and target_value is not None:
+                    if (new_pred - base_pred) <= float(target_value):
+                        success, new_out = True, {"pred": new_pred}
+                        break
+
+        suggestions.append({
+            "rule": (b+1, r+1),
+            "facts": [f"F{fid+1}" for fid in used_sorted],
+            "deltas": feat_deltas,
+            "verified": success,
+            "new_out": new_out
+        })
+    return suggestions
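The inversion step above assumes a β-fact activation of the form y = σ(k·diff)^ν, which is what the in-code comment implies; a plain-numpy round-trip check of that algebra:

import numpy as np

k, nu = 4.0, 1.5   # hypothetical sharpness and exponent
y_target = 0.1     # the default fact_target_level

# diff_target = logit(y_target ** (1/nu)) / k
p = y_target ** (1.0 / nu)
diff_target = np.log(p / (1.0 - p)) / k

# The forward pass recovers the target activation:
y_check = (1.0 / (1.0 + np.exp(-k * diff_target))) ** nu
print(np.isclose(y_check, y_target))  # True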
nous/explain/facts_desc.py
ADDED
@@ -0,0 +1,23 @@
+from __future__ import annotations
+from typing import Dict, Sequence
+from ..model import NousNet
+
+def render_fact_descriptions(model: NousNet, feature_names: Sequence[str], top_k_feats: int = 4, eps: float = 0.03) -> Dict[int, str]:
+    """
+    Create human-readable descriptions of base β-facts using (L-R) weights.
+    """
+    L, R, th, k, nu = model.fact.get_rule_parameters()
+    desc = {}
+    for fid in range(L.shape[0]):
+        net = L[fid] - R[fid]
+        pos = [(feature_names[i], net[i]) for i in range(len(net)) if net[i] > eps]
+        neg = [(feature_names[i], -net[i]) for i in range(len(net)) if net[i] < -eps]
+        pos = sorted(pos, key=lambda t: -abs(t[1]))[:top_k_feats]
+        neg = sorted(neg, key=lambda t: -abs(t[1]))[:top_k_feats]
+        pos_str = " + ".join([f"{w:.2f}·{n}" for n, w in pos]) if pos else "0"
+        neg_str = " + ".join([f"{w:.2f}·{n}" for n, w in neg]) if neg else "0"
+        base = f"β( [L−R](x̃) = ({pos_str}) − ({neg_str}) > {th[fid]:.2f}; k={k[fid]:.2f}, ν={nu[fid]:.2f} )"
+        if model.calibrators is not None:
+            base += " where x̃ are calibrated features"
+        desc[fid] = base
+    return desc
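A hypothetical call, assuming model is a trained NousNet and feature_names matches its inputs; the commented output is illustrative of the string shape only, not real values:

from nous import render_fact_descriptions

desc = render_fact_descriptions(model, feature_names, top_k_feats=3)
print(desc[0])
# β( [L−R](x̃) = (0.84·alcohol) − (0.31·hue) > 0.12; k=3.41, ν=1.20 )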
nous/explain/fidelity.py
ADDED
@@ -0,0 +1,56 @@
+from __future__ import annotations
+import numpy as np
+import torch
+from typing import Optional, Sequence, Dict, Any
+from ..model import NousNet
+from .mse import minimal_sufficient_explanation
+
+def explanation_fidelity_metrics(
+    model: NousNet, x_sample, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+    margin_tolerance: float = 0.0, prob_tolerance: float = 0.0, pred_tolerance_reg: float = 0.02,
+    device=None
+) -> Dict[str, float]:
+    """
+    Sufficiency and comprehensiveness metrics using MSE masks.
+    """
+    device = device or next(model.parameters()).device
+    task = model.config['task_type']
+
+    base_probas, base_logits, base_internals = model.forward_explain(x_sample, device=device)
+    if task == "classification":
+        pred_idx = int(np.argmax(base_probas))
+        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
+        base_margin = float(base_logits[pred_idx] - base_logits[runner_up])
+    else:
+        base_pred = float(base_logits[0])
+
+    mse = minimal_sufficient_explanation(
+        model, x_sample, feature_names, class_names=class_names,
+        margin_tolerance=margin_tolerance, prob_tolerance=prob_tolerance,
+        pred_tolerance_reg=pred_tolerance_reg, freeze_non_active=True, device=device
+    )
+    kept_masks = mse["kept_masks"]
+
+    prob_s, logit_s, _ = model.forward_explain(x_sample, restrict_masks=kept_masks, device=device)
+    inv_masks = [(torch.ones_like(m) - m) for m in kept_masks]
+    prob_c, logit_c, _ = model.forward_explain(x_sample, restrict_masks=inv_masks, device=device)
+
+    if task == "classification":
+        runner_up = int(np.argsort(base_logits)[-2]) if base_logits.size > 1 else pred_idx
+        margin_s = float(logit_s[pred_idx] - logit_s[runner_up])
+        margin_c = float(logit_c[pred_idx] - logit_c[runner_up])
+        return {
+            "base_margin": base_margin,
+            "sufficiency_margin": margin_s,
+            "comprehensiveness_margin": margin_c,
+            "kept_size": float(sum(int(m.sum().item()) for m in kept_masks))
+        }
+    else:
+        pred_s = float(logit_s[0])
+        pred_c = float(logit_c[0])
+        return {
+            "base_pred": base_pred,
+            "sufficiency_pred": pred_s,
+            "comprehensiveness_pred": pred_c,
+            "kept_size": float(sum(int(m.sum().item()) for m in kept_masks))
+        }
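For a faithful minimal explanation, sufficiency_margin should stay near base_margin while comprehensiveness_margin (the complement mask) collapses. A hypothetical call on a trained classifier (model, X_test, feature_names, class_names assumed):

from nous import explanation_fidelity_metrics

m = explanation_fidelity_metrics(model, X_test[0], feature_names, class_names=class_names)
print(m["base_margin"], m["sufficiency_margin"], m["comprehensiveness_margin"], m["kept_size"])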
nous/explain/generate.py
ADDED
@@ -0,0 +1,86 @@
+from __future__ import annotations
+import numpy as np
+from typing import Optional, Sequence
+from ..model import NousNet
+from .loo import rule_impact_df
+
+def generate_enhanced_explanation(
+    model: NousNet, x_sample, y_true, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None, y_scaler=None,
+    loo_mode: str = 'replacement', top_m_rules: Optional[int] = None, use_pre_norm: bool = False,
+    use_pruning: bool = False, pruning_threshold: float = 0.0, vs_class_idx: Optional[int] = None
+) -> str:
+    """
+    Human-readable explanation text with top rule impacts for a single sample.
+    """
+    task = model.config['task_type']
+    base_probas, base_logits, _ = model.forward_explain(
+        x_sample, apply_pruning=use_pruning, pruning_threshold=pruning_threshold
+    )
+
+    if y_scaler is not None:
+        y_true_unscaled = y_scaler.inverse_transform(np.array(y_true).reshape(-1, 1)).item()
+        base_pred_unscaled = y_scaler.inverse_transform(base_logits.reshape(-1, 1)).item()
+    else:
+        y_true_unscaled = y_true
+        base_pred_unscaled = base_logits[0] if task == 'regression' else None
+
+    if task == "classification":
+        pred_idx = int(np.argmax(base_probas))
+        pred_name = class_names[pred_idx] if class_names else f"Class {pred_idx}"
+        conf = float(base_probas[pred_idx])
+        true_name = class_names[y_true] if class_names is not None else str(y_true)
+    else:
+        pred_name = f"Value: {base_pred_unscaled:.3f}"
+        true_name = f"{y_true_unscaled:.3f}"
+
+    lines = []
+    model_tag = model.config['rule_selection_method'].upper()
+    lines.append(f"MODEL: {model_tag} rules | TASK: {task.upper()}")
+    lines.append(f"SAMPLE PREDICTION: {pred_name}")
+    if task == "classification":
+        lines.append(f" - Confidence: {conf:.3f}")
+        lines.append(f" - Ground Truth: {true_name}")
+    else:
+        lines.append(f" - Ground Truth: {true_name}")
+    if use_pruning:
+        lines.append(f" - Pruning: |act| >= {pruning_threshold:.4f} (forward uses pruned activations)")
+    lines.append("-"*60)
+
+    imp = rule_impact_df(
+        model, x_sample, feature_names, class_names=class_names, vs_class=vs_class_idx,
+        loo_mode=loo_mode, top_m_rules=top_m_rules, use_pre_norm=use_pre_norm,
+        use_pruning=use_pruning, pruning_threshold=pruning_threshold
+    )
+    if imp.empty:
+        lines.append("No active rules above threshold.")
+        return "\n".join(lines)
+
+    if vs_class_idx is not None and task == "classification":
+        lines.append(f"(Contrastive) Why '{pred_name}' vs '{class_names[vs_class_idx]}'?")
+    lines.append("CAUSAL RULE IMPACT (Top 5):")
+
+    for _, row in imp.head(5).iterrows():
+        b, r = row['block'], row['rule']
+        aggs, facts = row['aggregators'], row['facts']
+        badge = ""
+        if 'supports_pred' in row:
+            badge = " [+]" if row['supports_pred'] else " [-]"
+        if 'critical_flip' in row and row['critical_flip']:
+            badge += " [CRITICAL]"
+
+        if task == "classification":
+            margin_col = [c for c in imp.columns if c.startswith("Δmargin(")][0]
+            s = f"Δmargin={row[margin_col]:+.3f}{badge}"
+        else:
+            delta_pred = row['Δprediction']
+            if y_scaler is not None:
+                delta_pred_unscaled = delta_pred * y_scaler.scale_[0]
+                s = f"Δprediction={delta_pred_unscaled:+.3f}"
+            else:
+                s = f"Δprediction={delta_pred:+.3f}"
+        lines.append(f" • B{b}/R{r}: {s} | {aggs}\n {facts}")
+
+    if len(imp) > 5:
+        lines.append(f" ... and {len(imp) - 5} more active rules.")
+
+    return "\n".join(lines)
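A hypothetical call producing the report text (model, data splits, and names assumed from the dataset helpers above):

from nous import generate_enhanced_explanation

text = generate_enhanced_explanation(
    model, X_test[0], y_test[0], feature_names,
    class_names=class_names, loo_mode='replacement',
)
print(text)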
nous/explain/global_book.py
ADDED
@@ -0,0 +1,52 @@
+from __future__ import annotations
+import numpy as np
+import pandas as pd
+from typing import Optional, Sequence
+from ..model import NousNet
+from .loo import rule_impact_df
+
+def global_rulebook(
+    model: NousNet, X, feature_names: Sequence[str], class_names: Optional[Sequence[str]] = None,
+    pruning_threshold: Optional[float] = None, use_pruning: bool = False, allow_replacement: bool = True, freeze_non_active: bool = False,
+    max_samples: int = 1000, device=None
+) -> pd.DataFrame:
+    """
+    Global rulebook aggregation across samples via honest LOO impacts.
+    """
+    from tqdm.auto import tqdm  # optional dep in extras
+
+    device = device or next(model.parameters()).device
+    n = min(len(X), max_samples)
+    totals = {}
+
+    loo_mode = 'replacement'
+    if not allow_replacement and freeze_non_active:
+        loo_mode = 'frozen'
+
+    for i in tqdm(range(n), desc="Analyzing samples"):
+        df = rule_impact_df(
+            model, X[i], feature_names, class_names=class_names,
+            loo_mode=loo_mode, use_pruning=use_pruning, pruning_threshold=(pruning_threshold or 0.0), device=device
+        )
+        if df.empty:
+            continue
+        for _, row in df.iterrows():
+            key = (int(row["block"]), int(row["rule"]), row["aggregators"])
+            d = totals.setdefault(key, {"count": 0, "sum_abs_impact": 0.0, "critical": 0})
+            d["count"] += 1
+            if "critical_flip" in row and row["critical_flip"]:
+                d["critical"] += 1
+            metric_col = [c for c in row.index if c.startswith("Δmargin(")]
+            val = abs(float(row[metric_col[0]])) if metric_col else abs(float(row.get("Δprediction", 0.0)))
+            d["sum_abs_impact"] += val
+
+    rows = []
+    for (b, r, agg), v in totals.items():
+        count = v["count"]
+        rows.append({
+            "block": b, "rule": r, "aggregators": agg,
+            "activation_freq": count / n,
+            "mean_abs_impact": v["sum_abs_impact"] / count if count else 0.0,
+            "critical_rate": v["critical"] / count if count else 0.0
+        })
+    return pd.DataFrame(rows).sort_values("mean_abs_impact", ascending=False).reset_index(drop=True)
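A hypothetical call; tqdm is imported lazily above, so progress bars need the optional tqdm extra installed:

from nous import global_rulebook

rb = global_rulebook(model, X_val, feature_names, class_names=class_names, max_samples=200)
print(rb.head())  # rules ranked by mean_abs_impact, with activation_freq and critical_rate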