nous-0.1.0-py3-none-any.whl → nous-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of nous might be problematic (details on the registry page).
- nous/__init__.py +96 -19
- nous/data/__init__.py +4 -0
- nous/data/california.py +32 -0
- nous/data/wine.py +29 -0
- nous/explain/__init__.py +26 -0
- nous/explain/aggregator.py +34 -0
- nous/explain/cf.py +137 -0
- nous/explain/facts_desc.py +23 -0
- nous/explain/fidelity.py +56 -0
- nous/explain/generate.py +86 -0
- nous/explain/global_book.py +52 -0
- nous/explain/loo.py +130 -0
- nous/explain/mse.py +93 -0
- nous/explain/pruning.py +117 -0
- nous/explain/stability.py +42 -0
- nous/explain/traces.py +285 -0
- nous/explain/utils.py +15 -0
- nous/export/__init__.py +13 -0
- nous/export/numpy_infer.py +412 -0
- nous/facts.py +112 -0
- nous/model.py +226 -0
- nous/prototypes.py +43 -0
- nous/rules/__init__.py +11 -0
- nous/rules/blocks.py +63 -0
- nous/rules/fixed.py +26 -0
- nous/rules/softmax.py +93 -0
- nous/rules/sparse.py +142 -0
- nous/training/__init__.py +5 -0
- nous/training/evaluation.py +57 -0
- nous/training/schedulers.py +34 -0
- nous/training/train.py +177 -0
- nous/types.py +4 -0
- nous/utils/__init__.py +3 -0
- nous/utils/metrics.py +2 -0
- nous/utils/seed.py +13 -0
- nous/version.py +1 -0
- nous-0.2.0.dist-info/METADATA +150 -0
- nous-0.2.0.dist-info/RECORD +41 -0
- nous/causal.py +0 -63
- nous/interpret.py +0 -111
- nous/layers.py +0 -117
- nous/models.py +0 -65
- nous-0.1.0.dist-info/METADATA +0 -138
- nous-0.1.0.dist-info/RECORD +0 -10
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/WHEEL +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {nous-0.1.0.dist-info → nous-0.2.0.dist-info}/top_level.txt +0 -0
nous/model.py
ADDED
@@ -0,0 +1,226 @@
from __future__ import annotations
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, List, Optional, Sequence, Tuple, Union

from .facts import BetaFactLayer, PiecewiseLinearCalibrator
from .prototypes import ScaledPrototypeLayer
from .rules.blocks import SimpleNousBlock
from .rules.softmax import SoftmaxRuleLayer
from .rules.sparse import SparseRuleLayer

class NousNet(nn.Module):
    """
    NousNet: rule-based neural network with honest interpretability for classification and regression.
    """
    def __init__(
        self,
        input_dim: int,
        num_outputs: int,
        task_type: str = "classification",  # "classification" | "regression"
        feature_names: Optional[Sequence[str]] = None,
        num_facts: int = 48,
        rules_per_layer: Sequence[int] = (24, 12),
        use_calibrators: bool = False,
        rule_selection_method: str = "fixed",  # "fixed" | "softmax" | "sparse"
        use_prototypes: bool = False,
        l0_lambda: float = 1e-3,
        hc_temperature: float = 0.1
    ) -> None:
        super().__init__()
        self.config = {
            'input_dim': input_dim, 'num_outputs': num_outputs, 'task_type': task_type,
            'feature_names': list(feature_names) if feature_names is not None else [f"Feature_{i}" for i in range(input_dim)],
            'num_facts': num_facts, 'rules_per_layer': tuple(rules_per_layer),
            'use_calibrators': use_calibrators, 'rule_selection_method': rule_selection_method,
            'use_prototypes': bool(use_prototypes and task_type == "classification"),
            'l0_lambda': l0_lambda, 'hc_temperature': hc_temperature
        }

        if self.config['use_calibrators']:
            self.calibrators = nn.ModuleList([PiecewiseLinearCalibrator() for _ in range(input_dim)])
        else:
            self.calibrators = None

        self.fact = BetaFactLayer(input_dim, num_facts)

        blocks: List[nn.Module] = []
        cur = num_facts
        for r in rules_per_layer:
            if rule_selection_method == 'fixed':
                blocks.append(SimpleNousBlock(cur, r))
            elif rule_selection_method == 'softmax':
                blocks.append(SoftmaxRuleLayer(cur, r))
            elif rule_selection_method == 'sparse':
                blocks.append(SparseRuleLayer(cur, r, l0_lambda=l0_lambda, hc_temperature=hc_temperature))
            else:
                raise ValueError(f"Unknown rule_selection_method: {rule_selection_method}")
            cur = r
        self.blocks = nn.ModuleList(blocks)

        if self.config['use_prototypes']:
            self.head = ScaledPrototypeLayer(cur, num_prototypes=10, num_classes=num_outputs)
        else:
            self.head = nn.Linear(cur, num_outputs)

        if self.config['task_type'] == "regression" and self.config['num_outputs'] != 1:
            self.config['num_outputs'] = 1
            if isinstance(self.head, nn.Linear):
                self.head = nn.Linear(cur, 1)

    def forward(self, x: torch.Tensor, return_internals: bool = False):
        internals: Dict[str, torch.Tensor] = {}
        if self.calibrators is not None:
            x = torch.stack([calib(x[:, i]) for i, calib in enumerate(self.calibrators)], dim=1)

        h_facts = self.fact(x)
        h = h_facts
        if return_internals:
            internals['facts'] = h_facts.detach()

        for i, blk in enumerate(self.blocks):
            if return_internals:
                h, details = blk(h, return_details=True)
                internals[f'block_{i}'] = details
            else:
                h = blk(h)

        logits = self.head(h)
        if self.config['task_type'] == "regression":
            logits = logits.squeeze(-1)

        if return_internals:
            return logits, internals
        return logits

    @torch.no_grad()
    def forward_explain(
        self,
        x: Union[np.ndarray, torch.Tensor],
        drop_rule_spec: Optional[Tuple[int, int]] = None,
        restrict_masks: Optional[List[torch.Tensor]] = None,
        apply_pruning: bool = False,
        pruning_threshold: float = 0.0,
        device: Optional[torch.device] = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        """
        Honest forward for explanations with interventions/gating recompute.

        Returns:
        - classification: (probas, logits, internals)
        - regression: (pred, pred, internals)
        """
        self.eval()
        device = device or next(self.parameters()).device

        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float32, device=device).unsqueeze(0)
        elif isinstance(x, torch.Tensor) and x.dim() == 1:
            x = x.unsqueeze(0).to(device)
        else:
            x = x.to(device)

        if self.calibrators is not None:
            x_cal = torch.stack([calib(x[:, i]) for i, calib in enumerate(self.calibrators)], dim=1)
        else:
            x_cal = x

        h = self.fact(x_cal)
        internals: Dict[str, torch.Tensor] = {'facts': h.detach()}

        for i, blk in enumerate(self.blocks):
            drop_idx = None
            if drop_rule_spec is not None and drop_rule_spec[0] == i:
                drop_idx = int(drop_rule_spec[1])
            restrict = None
            if restrict_masks is not None and i < len(restrict_masks) and restrict_masks[i] is not None:
                restrict = restrict_masks[i].to(device)
            prune = pruning_threshold if apply_pruning else None

            h, details = blk(
                h,
                return_details=True,
                drop_rule_idx=drop_idx,
                restrict_mask=restrict,
                prune_below=prune,
                explain_disable_norm=explain_disable_norm,
                explain_exclude_proj=explain_exclude_proj
            )
            internals[f'block_{i}'] = details

        logits = self.head(h)
        if self.config['task_type'] == "classification":
            probas = F.softmax(logits, dim=-1)
            return probas.squeeze(0).cpu().numpy(), logits.squeeze(0).cpu().numpy(), internals
        else:
            pred = logits.squeeze(-1).squeeze(0).cpu().numpy()
            return np.array([pred]), np.array([pred]), internals

    def compute_total_l0_loss(self) -> torch.Tensor:
        if self.config['rule_selection_method'] != 'sparse':
            return torch.tensor(0.0, device=next(self.parameters()).device)
        total_l0 = 0.0
        for blk in self.blocks:
            if hasattr(blk, 'compute_l0_loss'):
                total_l0 += blk.compute_l0_loss()
        return total_l0

    def model_summary(self) -> Dict[str, object]:
        summary = {
            "Task": self.config['task_type'],
            "Rule Selection": self.config['rule_selection_method'],
            "Use Calibrators": self.config['use_calibrators'],
            "Use Prototypes": self.config['use_prototypes'],
            "Num Facts": self.config['num_facts'],
            "Rules per Layer": self.config['rules_per_layer'],
            "Total Parameters": sum(p.numel() for p in self.parameters())
        }
        if self.config['rule_selection_method'] == 'sparse':
            summary["L0 Lambda"] = self.config['l0_lambda']
        return summary

    @torch.no_grad()
    def encode(
        self,
        X: Union[np.ndarray, torch.Tensor, Sequence[Sequence[float]]],
        device: Optional[torch.device] = None,
        batch_size: int = 2048,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ) -> torch.Tensor:
        """
        Return H [N, D_last] — representations at the head input.

        Flags allow clean representations (without LayerNorm / without residual projection).
        """
        self.eval()
        device = device or next(self.parameters()).device

        if isinstance(X, np.ndarray):
            X_tensor = torch.tensor(X, dtype=torch.float32)
        elif isinstance(X, torch.Tensor):
            X_tensor = X.detach().cpu().float()
        else:
            X_tensor = torch.tensor(np.asarray(X), dtype=torch.float32)

        H_list = []
        for i in range(0, len(X_tensor), batch_size):
            xb = X_tensor[i:i+batch_size].to(device)

            if self.calibrators is not None:
                xb_cal = torch.stack([calib(xb[:, j]) for j, calib in enumerate(self.calibrators)], dim=1)
            else:
                xb_cal = xb

            h = self.fact(xb_cal)
            for blk in self.blocks:
                h, _ = blk(h, return_details=True,
                           explain_disable_norm=explain_disable_norm,
                           explain_exclude_proj=explain_exclude_proj)
            H_list.append(h.detach().cpu())
        H = torch.cat(H_list, dim=0)
        return H
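For orientation, the new top-level API can be exercised as follows. This is a minimal sketch: the array shapes, feature count, class count and hyperparameter values are illustrative assumptions; only the class, method and argument names come from the code above.

import numpy as np
import torch
from nous.model import NousNet

# Illustrative data: 32 samples, 8 features, 3 classes.
X = np.random.rand(32, 8).astype(np.float32)

model = NousNet(
    input_dim=8,
    num_outputs=3,
    task_type="classification",
    rule_selection_method="sparse",   # "fixed" | "softmax" | "sparse"
    num_facts=16,
    rules_per_layer=(12, 6),
)

logits = model(torch.from_numpy(X))                                # [32, 3]
logits, internals = model(torch.from_numpy(X), return_internals=True)

# Honest per-sample explanation: class probabilities, logits and per-block details.
probas, raw_logits, internals = model.forward_explain(X[0])

# L0 regularizer (non-zero only for the "sparse" rule selection method).
loss_l0 = model.compute_total_l0_loss()
print(model.model_summary())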
nous/prototypes.py
ADDED
@@ -0,0 +1,43 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict

class ScaledPrototypeLayer(nn.Module):
    """
    Prototype-based head for classification with temperature scaling on distances.
    Uses L2-normalized features/prototypes and exp(-tau * distance).
    """
    def __init__(self, input_dim: int, num_prototypes: int, num_classes: int) -> None:
        super().__init__()
        self.num_prototypes = num_prototypes
        self.num_classes = num_classes
        self.temperature = nn.Parameter(torch.tensor(1.0))
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, input_dim) * 0.1)
        self.prototype_class = nn.Parameter(torch.randn(num_prototypes, num_classes))

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        h_norm = F.normalize(h, p=2, dim=1)
        prototypes_norm = F.normalize(self.prototypes, p=2, dim=1)
        distances = torch.cdist(h_norm, prototypes_norm)
        activations = torch.exp(-F.softplus(self.temperature) * distances)
        class_logits = torch.matmul(activations, self.prototype_class)
        return class_logits

    @torch.no_grad()
    def get_params(self) -> Dict[str, torch.Tensor]:
        P = self.prototypes.detach()
        Pn = F.normalize(P, p=2, dim=1)
        W = self.prototype_class.detach()
        Wsm = F.softmax(W, dim=1)
        tau = F.softplus(self.temperature).detach()
        return dict(prototypes=P, prototypes_norm=Pn, class_weights=W, class_probs=Wsm, temperature=tau)

    @torch.no_grad()
    def compute_dist_act(self, h: torch.Tensor):
        h_norm = F.normalize(h, p=2, dim=1)
        Pn = F.normalize(self.prototypes, p=2, dim=1)
        d = torch.cdist(h_norm, Pn)
        act = torch.exp(-F.softplus(self.temperature) * d)
        return d, act
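The prototype head's score for a single prototype is exp(-softplus(temperature) * ||h_norm - p_norm||_2), mixed into class logits through prototype_class. A small self-check of that relationship (shapes and values are illustrative):

import torch
import torch.nn.functional as F
from nous.prototypes import ScaledPrototypeLayer

layer = ScaledPrototypeLayer(input_dim=6, num_prototypes=4, num_classes=3)
h = torch.randn(2, 6)

logits = layer(h)                    # [2, 3]
d, act = layer.compute_dist_act(h)   # distances [2, 4], activations [2, 4]

# Activations decay exponentially with distance, so they lie in (0, 1].
assert torch.allclose(act, torch.exp(-F.softplus(layer.temperature) * d))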
nous/rules/__init__.py
ADDED
@@ -0,0 +1,11 @@
from .fixed import FixedPairRuleLayer
from .softmax import SoftmaxRuleLayer
from .sparse import SparseRuleLayer
from .blocks import SimpleNousBlock

__all__ = [
    "FixedPairRuleLayer",
    "SoftmaxRuleLayer",
    "SparseRuleLayer",
    "SimpleNousBlock",
]
nous/rules/blocks.py
ADDED
@@ -0,0 +1,63 @@
from __future__ import annotations
import torch
import torch.nn as nn
from .fixed import FixedPairRuleLayer

class SimpleNousBlock(nn.Module):
    """
    Residual block over FixedPairRuleLayer with honest interventions and explain flags.

    Options:
    - drop_rule_idx: zero a single rule pre-gating
    - restrict_mask: 0/1 mask over rules (frozen active set)
    - prune_below: threshold prune post-gating
    - explain_disable_norm / explain_exclude_proj for clean contributions
    """
    def __init__(self, in_dim: int, n_rules: int) -> None:
        super().__init__()
        self.rule = FixedPairRuleLayer(in_dim, n_rules)
        self.proj = nn.Identity() if in_dim == n_rules else nn.Linear(in_dim, n_rules, bias=False)
        self.norm = nn.LayerNorm(n_rules)

    def forward(
        self,
        x: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        rule_activations = self.rule(x)

        if restrict_mask is not None:
            rule_activations = rule_activations * restrict_mask
        if drop_rule_idx is not None:
            rule_activations[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations
        if prune_below is not None:
            mask = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * mask

        proj_contrib = torch.zeros_like(gated_activations)
        if not isinstance(self.proj, nn.Identity):
            proj_contrib = self.proj(x)

        pre_sum = (0.0 * proj_contrib) + gated_activations if explain_exclude_proj else (proj_contrib + gated_activations)
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if return_details:
            details = {
                "pre_rule_activations": rule_activations.detach(),
                "gated_activations": gated_activations.detach(),
                "gate_mask": torch.ones_like(gated_activations, dtype=torch.float32),
                "aggregator_weights": None,
                "selected_indices": torch.arange(gated_activations.shape[1], device=x.device).unsqueeze(0).repeat(x.size(0), 1),
                "facts_used": self.rule.idx.detach(),
                "pre_norm_sum": pre_sum.detach(),
                "proj_contrib": proj_contrib.detach()
            }
            return output, details
        return output
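A short sketch of the intervention flags on the residual block (toy shapes and illustrative values; the intervention call runs under no_grad, as the explanation paths do):

import torch
from nous.rules.blocks import SimpleNousBlock

block = SimpleNousBlock(in_dim=16, n_rules=8)
facts = torch.rand(4, 16)            # fact activations in [0, 1]

out, details = block(facts, return_details=True)

with torch.no_grad():
    # Drop rule 3 pre-gating and prune activations below 0.05, recomputing honestly.
    out_drop, details_drop = block(
        facts,
        return_details=True,
        drop_rule_idx=3,
        prune_below=0.05,
    )
assert torch.all(details_drop["gated_activations"][:, 3] == 0)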
nous/rules/fixed.py
ADDED
@@ -0,0 +1,26 @@
from __future__ import annotations
import torch
import torch.nn as nn

class FixedPairRuleLayer(nn.Module):
    """
    Fixed random fact pairs with AND logic. Pair indices are distinct per rule.
    Output: rule_strength * (f_i * f_j)
    """
    def __init__(self, input_dim: int, num_rules: int) -> None:
        super().__init__()
        idx1 = torch.randint(0, input_dim, (num_rules,))
        offset = torch.randint(1, input_dim, (num_rules,))
        idx2 = (idx1 + offset) % input_dim
        idx = torch.stack([idx1, idx2], dim=1)
        self.register_buffer('idx', idx)
        self.weight = nn.Parameter(torch.ones(num_rules))

    def forward(self, facts: torch.Tensor) -> torch.Tensor:
        f1, f2 = facts[:, self.idx[:, 0]], facts[:, self.idx[:, 1]]
        rule_strength = torch.sigmoid(self.weight)
        return rule_strength * (f1 * f2)

    @torch.no_grad()
    def get_rules(self):
        return self.idx.cpu().numpy(), torch.sigmoid(self.weight).cpu().numpy()
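The fixed-pair AND semantics in isolation; a toy check (sizes are illustrative):

import torch
from nous.rules.fixed import FixedPairRuleLayer

layer = FixedPairRuleLayer(input_dim=10, num_rules=5)
facts = torch.rand(3, 10)

out = layer(facts)                    # [3, 5]
pairs, strengths = layer.get_rules()  # pair indices [5, 2], sigmoid(weight) [5]

# Rule 0 is strength_0 * f_i * f_j for its fixed (i, j) pair.
i, j = int(pairs[0, 0]), int(pairs[0, 1])
expected = float(strengths[0]) * facts[:, i] * facts[:, j]
assert torch.allclose(out[:, 0], expected)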
nous/rules/softmax.py
ADDED
@@ -0,0 +1,93 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

class SoftmaxRuleLayer(nn.Module):
    """
    Learnable fact connections via softmax gating with per-rule top-k and aggregator mixing.
    Aggregators: AND, OR, k-of-n. Honest interventions supported.
    """
    def __init__(self, input_dim: int, num_rules: int, top_k_facts: int = 2, top_k_rules: int = 8) -> None:
        super().__init__()
        self.input_dim = input_dim
        self.num_rules = num_rules
        self.top_k_facts = top_k_facts
        self.top_k_rules = top_k_rules

        self.fact_logits = nn.Parameter(torch.randn(num_rules, input_dim) * 0.01)
        self.num_aggregators = 3
        self.aggregator_logits = nn.Parameter(torch.zeros(num_rules, self.num_aggregators))
        self.rule_strength_raw = nn.Parameter(torch.zeros(num_rules))
        self.proj = nn.Linear(input_dim, num_rules, bias=False) if input_dim != num_rules else nn.Identity()
        self.norm = nn.LayerNorm(num_rules)

    def forward(
        self,
        facts: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        fact_logits_soft = F.softmax(self.fact_logits, dim=1)
        _, topk_fact_idx = torch.topk(fact_logits_soft, k=min(self.top_k_facts, self.input_dim), dim=1)
        mask = torch.zeros_like(fact_logits_soft)
        mask.scatter_(1, topk_fact_idx, 1.0)

        facts_expanded = facts.unsqueeze(1)
        mask_expanded = mask.unsqueeze(0)
        selected_facts = facts_expanded * mask_expanded

        and_agg = torch.prod(selected_facts + (1 - mask_expanded), dim=2)
        or_agg = 1 - torch.prod((1 - selected_facts) * mask_expanded + (1 - mask_expanded), dim=2)
        k_of_n_agg = torch.sum(selected_facts, dim=2) / (mask_expanded.sum(dim=2) + 1e-8)

        agg_weights = F.softmax(self.aggregator_logits, dim=1)
        aggregators = torch.stack([and_agg, or_agg, k_of_n_agg], dim=2)
        mixed_agg = torch.sum(aggregators * agg_weights.unsqueeze(0), dim=2)

        rule_strength = torch.sigmoid(self.rule_strength_raw)
        rule_activations = mixed_agg * rule_strength.unsqueeze(0)

        pre_for_topk = rule_activations.clone()
        if restrict_mask is not None:
            pre_for_topk = pre_for_topk + (restrict_mask - 1) * 1e9
        if drop_rule_idx is not None:
            pre_for_topk[:, drop_rule_idx] = -1e9

        k = min(self.top_k_rules, self.num_rules)
        _, topk_rule_idx = torch.topk(pre_for_topk, k=k, dim=1)
        gate_mask = torch.zeros_like(rule_activations)
        gate_mask.scatter_(1, topk_rule_idx, 1.0)

        if restrict_mask is not None:
            gate_mask = gate_mask * restrict_mask.unsqueeze(0).to(gate_mask.dtype)
        if drop_rule_idx is not None:
            gate_mask[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations * gate_mask
        if prune_below is not None:
            keep = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * keep
            gate_mask = gate_mask * keep

        proj_contrib = self.proj(facts) if not isinstance(self.proj, nn.Identity) else facts
        pre_sum = gated_activations if explain_exclude_proj else (proj_contrib + gated_activations)
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if return_details:
            details = {
                "pre_rule_activations": rule_activations.detach(),
                "gated_activations": gated_activations.detach(),
                "gate_mask": gate_mask.detach(),
                "aggregator_weights": agg_weights.detach(),
                "selected_indices": topk_rule_idx.detach(),
                "facts_used": topk_fact_idx.detach(),
                "pre_norm_sum": pre_sum.detach(),
                "proj_contrib": proj_contrib.detach()
            }
            return output, details
        return output
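With the top-k fact mask fixed, the three aggregators reduce to a product (soft AND), a complement of a product (soft OR), and a mean (k-of-n) over the selected facts; the per-rule mixing weights and the applied gate are exposed in the details dict. A toy run (sizes are illustrative):

import torch
from nous.rules.softmax import SoftmaxRuleLayer

# For two selected facts f1, f2 the aggregators are:
#   AND    = f1 * f2
#   OR     = 1 - (1 - f1) * (1 - f2)
#   k-of-n = (f1 + f2) / 2
layer = SoftmaxRuleLayer(input_dim=12, num_rules=6, top_k_facts=2, top_k_rules=4)
facts = torch.rand(5, 12)

out, details = layer(facts, return_details=True)
print(details["aggregator_weights"].shape)   # [6, 3] mixing over (AND, OR, k-of-n)
print(details["gate_mask"].sum(dim=1))       # 4 rules kept per sample (top_k_rules)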
nous/rules/sparse.py
ADDED
@@ -0,0 +1,142 @@
from __future__ import annotations
import torch
import torch.nn as nn
import torch.nn.functional as F

class HardConcretePerConnection(nn.Module):
    """
    Per-connection Hard-Concrete gate with safer sampling (clamped u and log_sigma).
    """
    def __init__(self, shape, init_mean=0.0, init_std=1.0, temperature=0.1, stretch_limits=(-0.1, 1.1)):
        super().__init__()
        self.shape = shape
        self.beta = nn.Parameter(torch.ones(shape) * init_mean)
        self.log_sigma = nn.Parameter(torch.ones(shape) * init_std)
        self.temperature = temperature
        self.l, self.r = stretch_limits

    def forward(self, training: bool | None = None) -> torch.Tensor:
        if training is None:
            training = self.training
        if training:
            u = torch.rand(self.shape, device=self.beta.device).clamp(1e-6, 1 - 1e-6)
            sigma = torch.exp(torch.clamp(self.log_sigma, -5.0, 5.0))
            pre = (torch.log(u) - torch.log(1 - u) + self.beta) / (sigma + 1e-8)
            s = torch.sigmoid(pre)
            s = s * (self.r - self.l) + self.l
            left = torch.sigmoid((s - self.l) / self.temperature)
            right = torch.sigmoid((s - self.r) / self.temperature)
            mask = (left - right).clamp(0.0, 1.0)
        else:
            mask = (torch.sigmoid(self.beta) > 0.5).float()
        return mask

    def get_proba(self) -> torch.Tensor:
        l_tensor = torch.tensor(self.l, device=self.beta.device)
        r_tensor = torch.tensor(self.r, device=self.beta.device)
        return torch.sigmoid(torch.clamp(self.beta, -10.0, 10.0) - self.temperature * torch.log(-l_tensor / r_tensor + 1e-8))

    def l0_penalty(self) -> torch.Tensor:
        return self.get_proba().sum()


class SparseRuleLayer(nn.Module):
    """
    Hard-Concrete sparse connections + aggregator mixing + honest interventions.
    Aggregators: AND, OR, k-of-n, NOT.
    """
    def __init__(self, input_dim: int, num_rules: int, top_k_facts: int = 2, top_k_rules: int = 8, l0_lambda: float = 1e-3, hc_temperature: float = 0.1):
        super().__init__()
        self.input_dim = input_dim
        self.num_rules = num_rules
        self.top_k_facts = top_k_facts
        self.top_k_rules = top_k_rules
        self.l0_lambda = l0_lambda

        self.hard_concrete = HardConcretePerConnection(
            shape=(num_rules, input_dim),
            init_mean=0.0,
            init_std=1.0,
            temperature=hc_temperature,
            stretch_limits=(-0.1, 1.1)
        )

        self.num_aggregators = 4
        self.aggregator_logits = nn.Parameter(torch.zeros(num_rules, self.num_aggregators))
        self.rule_strength_raw = nn.Parameter(torch.zeros(num_rules))
        self.proj = nn.Linear(input_dim, num_rules, bias=False) if input_dim != num_rules else nn.Identity()
        self.norm = nn.LayerNorm(num_rules)

    def forward(
        self,
        facts: torch.Tensor,
        return_details: bool = False,
        drop_rule_idx: int | None = None,
        restrict_mask: torch.Tensor | None = None,
        prune_below: float | None = None,
        explain_disable_norm: bool = False,
        explain_exclude_proj: bool = False
    ):
        mask = self.hard_concrete(training=self.training)  # [R, F]
        facts_expanded = facts.unsqueeze(1)
        mask_expanded = mask.unsqueeze(0)
        selected_facts = facts_expanded * mask_expanded

        and_agg = torch.prod(selected_facts + (1 - mask_expanded), dim=2)
        or_agg = 1 - torch.prod((1 - selected_facts) * mask_expanded + (1 - mask_expanded), dim=2)
        selected_count = torch.sum(mask_expanded, dim=2) + 1e-8
        k_of_n_agg = torch.sum(selected_facts, dim=2) / selected_count
        not_agg = 1 - k_of_n_agg

        agg_weights = F.softmax(self.aggregator_logits, dim=1)  # [R, 4]
        aggregators = torch.stack([and_agg, or_agg, k_of_n_agg, not_agg], dim=2)
        mixed_agg = torch.sum(aggregators * agg_weights.unsqueeze(0), dim=2)  # [B, R]

        rule_strength = torch.sigmoid(self.rule_strength_raw)  # [R]
        rule_activations = mixed_agg * rule_strength.unsqueeze(0)  # [B, R]

        pre_for_topk = rule_activations.clone()
        if restrict_mask is not None:
            pre_for_topk = pre_for_topk + (restrict_mask - 1) * 1e9
        if drop_rule_idx is not None:
            pre_for_topk[:, drop_rule_idx] = -1e9

        k = min(self.top_k_rules, self.num_rules)
        _, topk_rule_idx = torch.topk(pre_for_topk, k=k, dim=1)
        gate_mask = torch.zeros_like(rule_activations)
        gate_mask.scatter_(1, topk_rule_idx, 1.0)

        if restrict_mask is not None:
            gate_mask = gate_mask * restrict_mask.unsqueeze(0).to(gate_mask.dtype)
        if drop_rule_idx is not None:
            gate_mask[:, drop_rule_idx] = 0.0

        gated_activations = rule_activations * gate_mask
        if prune_below is not None:
            keep = (gated_activations.abs() >= prune_below).float()
            gated_activations = gated_activations * keep
            gate_mask = gate_mask * keep

        proj_contrib = self.proj(facts) if not isinstance(self.proj, nn.Identity) else facts
        pre_sum = gated_activations if explain_exclude_proj else (proj_contrib + gated_activations)
        output = pre_sum if explain_disable_norm else self.norm(pre_sum)

        if return_details:
            with torch.no_grad():
                _, topk_fact_idx = torch.topk(mask, k=min(self.top_k_facts, self.input_dim), dim=1)
            details = {
                "pre_rule_activations": rule_activations.detach(),
                "gated_activations": gated_activations.detach(),
                "gate_mask": gate_mask.detach(),
                "aggregator_weights": agg_weights.detach(),
                "selected_indices": topk_rule_idx.detach(),
                "facts_used": topk_fact_idx.detach(),
                "pre_norm_sum": pre_sum.detach(),
                "proj_contrib": proj_contrib.detach()
            }
            return output, details
        return output

    def compute_l0_loss(self) -> torch.Tensor:
        num_conn = float(self.num_rules * self.input_dim)
        return (self.l0_lambda * self.hard_concrete.l0_penalty()) / max(1.0, num_conn)
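A sketch of the Hard-Concrete gating and the per-connection L0 penalty (illustrative sizes; note that at initialization the deterministic eval-time mask may keep no connections, since sigmoid(0) = 0.5 is not above the 0.5 threshold):

import torch
from nous.rules.sparse import SparseRuleLayer

layer = SparseRuleLayer(input_dim=12, num_rules=6, l0_lambda=1e-3, hc_temperature=0.1)
facts = torch.rand(5, 12)

layer.train()
out = layer(facts)                               # stochastic gates during training
l0 = layer.compute_l0_loss()                     # scalar penalty to add to the task loss
print(float(l0))

layer.eval()
hard_mask = layer.hard_concrete(training=False)  # deterministic 0/1 mask, shape [6, 12]
print(int(hard_mask.sum().item()), "connections kept")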
nous/training/evaluation.py
ADDED
@@ -0,0 +1,57 @@
from __future__ import annotations
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import (
    roc_auc_score, accuracy_score, mean_absolute_error,
    mean_squared_error, r2_score
)

def evaluate_classification(model, test_loader: DataLoader, device, class_names=None):
    """
    Evaluate classification metrics on a dataloader: accuracy and one-vs-rest AUC.
    """
    model.eval()
    all_probs, all_labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            probs = F.softmax(model(X_batch), dim=1).cpu().numpy()
            all_probs.extend(probs)
            all_labels.extend(y_batch.numpy())
    all_probs = np.array(all_probs)
    all_preds = np.argmax(all_probs, axis=1)
    acc = accuracy_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')
    except ValueError:
        auc = 0.5
    return acc, auc, all_probs, all_labels

def evaluate_regression(model, test_loader: DataLoader, device, y_scaler=None):
    """
    Evaluate regression: RMSE, MAE, R^2. Returns unscaled predictions if y_scaler is passed.
    """
    model.eval()
    preds, labels = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            pred = model(X_batch).cpu().numpy().ravel()
            preds.extend(pred)
            labels.extend(y_batch.numpy().ravel())

    preds = np.array(preds)
    labels = np.array(labels)

    if y_scaler is not None:
        preds_u = y_scaler.inverse_transform(preds.reshape(-1, 1)).ravel()
        labels_u = y_scaler.inverse_transform(labels.reshape(-1, 1)).ravel()
    else:
        preds_u, labels_u = preds, labels

    rmse = np.sqrt(mean_squared_error(labels_u, preds_u))
    mae = mean_absolute_error(labels_u, preds_u)
    r2 = r2_score(labels_u, preds_u)
    return rmse, mae, r2, preds_u, labels_u
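Usage sketch for the evaluation helpers. The data, loader and model below are illustrative assumptions; the import path follows the nous/training/evaluation.py entry in the file listing.

import torch
from torch.utils.data import DataLoader, TensorDataset
from nous.model import NousNet
from nous.training.evaluation import evaluate_classification

# Illustrative held-out data: 64 samples, 8 features, 3 classes.
X_test = torch.rand(64, 8)
y_test = torch.randint(0, 3, (64,))
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=32)

device = torch.device("cpu")
model = NousNet(input_dim=8, num_outputs=3, task_type="classification").to(device)

acc, auc, probs, labels = evaluate_classification(model, test_loader, device)
print(f"acc={acc:.3f}  auc={auc:.3f}")

# Regression counterpart; y_scaler (e.g. a fitted sklearn scaler) is optional:
# rmse, mae, r2, preds, labels = evaluate_regression(reg_model, reg_loader, device, y_scaler=scaler)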